mirror of https://github.com/coqui-ai/TTS.git

commit 12722501bb
parent 8b1014d188

    styling
@@ -1 +1 @@
-__version__ = '0.0.13.2'
+__version__ = "0.0.13.2"

@@ -1,25 +1,26 @@
 #!/usr/bin/env python3
 """Extract Mel spectrograms with teacher forcing."""
 
-import os
 import argparse
+import os
 
 import numpy as np
-from tqdm import tqdm
 import torch
-
 from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from TTS.config import load_config
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
-from TTS.config import load_config
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters
 
 use_cuda = torch.cuda.is_available()
 
+
 def setup_loader(ap, r, verbose=False):
     dataset = MyDataset(
         r,
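The import reshuffle in the hunk above matches isort's default three-group ordering; a sketch of the resulting layout (the group labels describe the tool's convention and are not part of the commit):

    # group 1: standard library, alphabetized
    import argparse
    import os

    # group 2: third-party packages, alphabetized
    import numpy as np
    import torch

    # group 3: first-party ("TTS") modules, alphabetized by module path
    from TTS.config import load_config
    from TTS.utils.audio import AudioProcessor
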
@@ -38,9 +39,7 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_mapping
-        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
-        else None,
+        speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
     )
 
     if c.use_phonemes and c.compute_input_seq_cache:
@@ -60,19 +59,21 @@ def setup_loader(ap, r, verbose=False):
     )
     return loader
 
+
 def set_filename(wav_path, out_path):
     wav_file = os.path.basename(wav_path)
-    file_name = wav_file.split('.')[0]
+    file_name = wav_file.split(".")[0]
     os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
     wavq_path = os.path.join(out_path, "quant", file_name)
     mel_path = os.path.join(out_path, "mel", file_name)
-    wav_gl_path = os.path.join(out_path, "wav_gl", file_name+'.wav')
-    wav_path = os.path.join(out_path, "wav", file_name+'.wav')
+    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
+    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
     return file_name, wavq_path, mel_path, wav_gl_path, wav_path
 
+
 def format_data(data):
     # setup input data
     text_input = data[0]
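For reference, a usage sketch of set_filename from the hunk above; the input paths are made up, and the function also creates the four output subdirectories as a side effect:

    file_name, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(
        "/data/ljspeech/wavs/LJ001-0001.wav", "/tmp/specs"
    )
    # file_name   == "LJ001-0001"
    # wavq_path   == "/tmp/specs/quant/LJ001-0001"
    # mel_path    == "/tmp/specs/mel/LJ001-0001"
    # wav_gl_path == "/tmp/specs/wav_gl/LJ001-0001.wav"
    # wav_path    == "/tmp/specs/wav/LJ001-0001.wav"
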
@@ -123,10 +124,22 @@ def format_data(data):
         item_idx,
     )
 
+
 @torch.no_grad()
-def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
+def inference(
+    model_name,
+    model,
+    ap,
+    text_input,
+    text_lengths,
+    mel_input,
+    mel_lengths,
+    attn_mask=None,
+    speaker_ids=None,
+    speaker_embeddings=None,
+):
     if model_name == "glow_tts":
-        mel_input = mel_input.permute(0, 2, 1) # B x D x T
+        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
         speaker_c = None
         if speaker_ids is not None:
             speaker_c = speaker_ids
@@ -140,7 +153,13 @@ def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_le
 
     elif "tacotron" in model_name:
         _, postnet_outputs, *_ = model(
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+            text_input,
+            text_lengths,
+            mel_input,
+            mel_lengths,
+            speaker_ids=speaker_ids,
+            speaker_embeddings=speaker_embeddings,
+        )
         # normalize tacotron output
         if model_name == "tacotron":
             mel_specs = []
@@ -154,7 +173,10 @@ def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_le
         model_output = postnet_outputs.detach().cpu().numpy()
     return model_output
 
-def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"):
+
+def extract_spectrograms(
+    data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
+):
     model.eval()
     export_metadata = []
     for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
@@ -173,7 +195,18 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
             item_idx,
         ) = format_data(data)
 
-        model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
+        model_output = inference(
+            c.model.lower(),
+            model,
+            ap,
+            text_input,
+            text_lengths,
+            mel_input,
+            mel_lengths,
+            attn_mask,
+            speaker_ids,
+            speaker_embeddings,
+        )
 
         for idx in range(text_input.shape[0]):
             wav_file_path = item_idx[idx]
@@ -204,13 +237,14 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
         for data in export_metadata:
             f.write(f"{data[0]}|{data[1]+'.npy'}\n")
 
+
 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
     global meta_data, symbols, phonemes, model_characters, speaker_mapping
 
     # Audio processor
     ap = AudioProcessor(**c.audio)
-    if "characters" in c.keys() and c['characters']:
+    if "characters" in c.keys() and c["characters"]:
         symbols, phonemes = make_symbols(**c.characters)
 
     # set model characters
@@ -242,37 +276,26 @@ def main(args):  # pylint: disable=redefined-outer-name
     r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
     own_loader = setup_loader(ap, r, verbose=True)
 
-    extract_spectrograms(own_loader, model, ap, args.output_path, quantized_wav=args.quantized, save_audio=args.save_audio, debug=args.debug, metada_name="metada.txt")
+    extract_spectrograms(
+        own_loader,
+        model,
+        ap,
+        args.output_path,
+        quantized_wav=args.quantized,
+        save_audio=args.save_audio,
+        debug=args.debug,
+        metada_name="metada.txt",
+    )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--config_path',
-        type=str,
-        help='Path to config file for training.',
-        required=True)
-    parser.add_argument(
-        '--checkpoint_path',
-        type=str,
-        help='Model file to be restored.',
-        required=True)
-    parser.add_argument(
-        '--output_path',
-        type=str,
-        help='Path to save mel specs',
-        required=True)
-    parser.add_argument('--debug',
-                        default=False,
-                        action='store_true',
-                        help='Save audio files for debug')
-    parser.add_argument('--save_audio',
-                        default=False,
-                        action='store_true',
-                        help='Save audio files')
-    parser.add_argument('--quantized',
-                        action='store_true',
-                        help='Save quantized audio files')
+    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
+    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
+    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
+    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
+    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
+    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
    args = parser.parse_args()
 
     c = load_config(args.config_path)
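A sample invocation of the reformatted script; the flag names come from the argparse block above and the script path from the tests further down, while the concrete file paths are placeholders:

    python TTS/bin/extract_tts_spectrograms.py \
        --config_path config.json \
        --checkpoint_path checkpoint.pth.tar \
        --output_path output/ \
        --save_audio --debug
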
@@ -2,10 +2,10 @@
 # TODO: mixed precision training
 """Trains GAN based vocoder model."""
 
+import itertools
 import os
 import sys
 import time
-import itertools
 import traceback
 from inspect import signature
 
@@ -496,8 +496,12 @@ def main(args):  # pylint: disable=redefined-outer-name
     optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
     optimizer_disc = getattr(torch.optim, c.optimizer)
 
-    if c.discriminator_model == 'hifigan_discriminator':
-        optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params)
+    if c.discriminator_model == "hifigan_discriminator":
+        optimizer_disc = optimizer_disc(
+            itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()),
+            lr=c.lr_disc,
+            **c.optimizer_params,
+        )
     else:
         optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
 
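The hifigan branch above builds a single optimizer over both sub-discriminators by chaining their parameter iterators. A self-contained sketch of the pattern, with stand-in modules and a hard-coded optimizer class in place of model_disc and getattr(torch.optim, c.optimizer):

    import itertools

    import torch

    msd = torch.nn.Linear(4, 4)  # stand-in for model_disc.msd
    mpd = torch.nn.Linear(4, 4)  # stand-in for model_disc.mpd

    # itertools.chain concatenates the two parameter generators, so one
    # optimizer instance updates both modules.
    optimizer_disc = torch.optim.AdamW(
        itertools.chain(msd.parameters(), mpd.parameters()), lr=1e-4
    )
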
@@ -1,5 +1,6 @@
 from dataclasses import asdict, dataclass
+from typing import List, Union
 
 from coqpit import MISSING, Coqpit, check_argument
 
 
@@ -151,11 +152,14 @@ class BaseDatasetConfig(Coqpit):
             Path to the file that lists the attention mask files used with models that require attention masks to
             train the duration predictor.
     """
-    name: str = ''
-    path: str = ''
-    meta_file_train: Union[str, List] = ''  # TODO: don't take ignored speakers for multi-speaker datasets over this. This is Union for SC-Glow compat.
-    meta_file_val: str = ''
-    meta_file_attn_mask: str = ''
+
+    name: str = ""
+    path: str = ""
+    meta_file_train: Union[
+        str, List
+    ] = ""  # TODO: don't take ignored speakers for multi-speaker datasets over this. This is Union for SC-Glow compat.
+    meta_file_val: str = ""
+    meta_file_attn_mask: str = ""
 
     def check_values(
         self,
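A construction sketch for the dataclass above; the field values are illustrative only (Coqpit configs are dataclasses, so keyword construction works):

    dataset_config = BaseDatasetConfig(
        name="ljspeech",
        path="/data/LJSpeech-1.1/",
        meta_file_train="metadata.csv",  # may also be a List, per the Union annotation
    )
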
@@ -59,7 +59,6 @@ class GlowTTSConfig(BaseTTSConfig):
         Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
     """
 
-
     model: str = "glow_tts"
 
     # model params

@@ -22,6 +22,7 @@ class GSTConfig(Coqpit):
         gst_num_style_tokens (int):
             Number of style token vectors. Defaults to 10.
     """
+
     gst_style_input_wav: str = None
     gst_style_input_weights: dict = None
     gst_embedding_dim: int = 256
@@ -130,7 +131,7 @@ class BaseTTSConfig(BaseTrainingConfig):
         datasets (List[BaseDatasetConfig]):
             List of datasets used for training. If multiple datasets are provided, they are merged and used together
             for training.
-    """
+    """
 
     audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
     # phoneme settings

@@ -113,4 +113,4 @@ class SpeedySpeechConfig(BaseTTSConfig):
     # overrides
     min_seq_len: int = 13
     max_seq_len: int = 200
-    r: int = 1  #DO NOT CHANGE
+    r: int = 1  # DO NOT CHANGE

@@ -90,7 +90,6 @@ class MultibandMelganConfig(BaseGANVocoderConfig):
         L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
     """
 
-
     model: str = "multiband_melgan"
 
     # Model specific params

@@ -66,6 +66,7 @@ class WavegradConfig(BaseVocoderConfig):
         lr_scheduler_params (dict):
             kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`
     """
+
     model: str = "wavegrad"
     # Model specific params
     generator_model: str = "wavegrad"

@@ -3,13 +3,9 @@ import unittest
 
 import torch
-
-from tests import get_tests_input_path
-
-from tests import get_tests_output_path, run_cli
-
-from TTS.tts.utils.generic_utils import setup_model
 
+from tests import get_tests_input_path, get_tests_output_path, run_cli
 from TTS.config import load_config
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.text.symbols import phonemes, symbols
 
 torch.manual_seed(1)
@@ -20,8 +16,8 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_GlowTTS():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -30,14 +26,17 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
+
     @staticmethod
     def test_Tacotron2():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -46,14 +45,17 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')
+
     @staticmethod
     def test_Tacotron():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
-        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
+        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth.tar")
+        output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
         # create model
@@ -62,5 +64,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
         # save model
         torch.save({"model": model.state_dict()}, checkpoint_path)
         # run test
-        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
+        run_cli(
+            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"'
+        )
         run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

@@ -130,6 +130,7 @@ class GlowTTSTrainTest(unittest.TestCase):
             )
             count += 1
 
+
 class GlowTTSInferenceTest(unittest.TestCase):
     @staticmethod
     def test_inference():
@@ -174,13 +175,12 @@ class GlowTTSInferenceTest(unittest.TestCase):
         print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
 
         # inference encoder and decoder with MAS
-        y, *_ = model.inference_with_MAS(
-            input_dummy, input_lengths, mel_spec, mel_lengths, None
-        )
+        y, *_ = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths, None)
 
-        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths
-        )
+        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)
 
-        assert (y_dec.shape == y.shape), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
-            y.shape, y_dec.shape
-        )
+        assert (
+            y_dec.shape == y.shape
+        ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
+            y.shape, y_dec.shape
+        )