diff --git a/.gitignore b/.gitignore index 64d1f0d5..7e9da0d8 100644 --- a/.gitignore +++ b/.gitignore @@ -128,6 +128,8 @@ core recipes/WIP/* recipes/ljspeech/LJSpeech-1.1/* recipes/vctk/VCTK/* +recipes/**/*.npy +recipes/**/*.json VCTK-Corpus-removed-silence/* # ignore training logs @@ -161,4 +163,5 @@ speakers.json internal/* *_pitch.npy *_phoneme.npy -wandb \ No newline at end of file +wandb +depot/* \ No newline at end of file diff --git a/TTS/.models.json b/TTS/.models.json index 6853e18e..985fea37 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -1,5 +1,17 @@ { "tts_models": { + "multilingual":{ + "multi-dataset":{ + "your_tts":{ + "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip", + "default_vocoder": null, + "commit": "e9a1953e", + "license": "CC BY-NC-ND 4.0", + "contact": "egolge@coqui.ai" + } + } + }, "en": { "ek1": { "tacotron2": { diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 0af98ff1..7b489fd6 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -12,7 +12,7 @@ from tqdm import tqdm from TTS.config import load_config from TTS.tts.datasets import TTSDataset, load_tts_samples from TTS.tts.models import setup_model -from TTS.tts.utils.speakers import get_speaker_manager +from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters @@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False): enable_eos_bos=c.enable_eos_bos_chars, use_noise_augment=False, verbose=verbose, - speaker_id_mapping=speaker_manager.speaker_ids, - d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None, + speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None, + d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None, ) if c.use_phonemes and c.compute_input_seq_cache: @@ -234,8 +234,13 @@ def main(args): # pylint: disable=redefined-outer-name # use eval and training partitions meta_data = meta_data_train + meta_data_eval - # parse speakers - speaker_manager = get_speaker_manager(c, args, meta_data_train) + # init speaker manager + if c.use_speaker_embedding: + speaker_manager = SpeakerManager(data_items=meta_data) + elif c.use_d_vector_file: + speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file) + else: + speaker_manager = None # setup model model = setup_model(c) diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py new file mode 100644 index 00000000..d3143ca3 --- /dev/null +++ b/TTS/bin/find_unique_phonemes.py @@ -0,0 +1,62 @@ +"""Find all the unique characters in a dataset""" +import argparse +import multiprocessing +from argparse import RawTextHelpFormatter + +from tqdm.contrib.concurrent import process_map + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.text import text2phone + + +def compute_phonemes(item): + try: + text = item[0] + language = item[-1] + ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|") + except: + return [] + return list(set(ph)) + + +def main(): + # pylint: disable=W0601 + global c + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a 
dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_chars.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples(c.datasets, eval_split=True) + items = train_items + eval_items + print("Num items:", len(items)) + + phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15) + phones = [] + for ph in phonemes: + phones.extend(ph) + phones = set(phones) + lower_phones = filter(lambda c: c.islower(), phones) + phones_force_lower = [c.lower() for c in phones] + phones_force_lower = set(phones_force_lower) + + print(f" > Number of unique phonemes: {len(phones)}") + print(f" > Unique phonemes: {''.join(sorted(phones))}") + print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}") + print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py new file mode 100755 index 00000000..9070f2da --- /dev/null +++ b/TTS/bin/remove_silence_using_vad.py @@ -0,0 +1,89 @@ +import argparse +import glob +import multiprocessing +import os +import pathlib + +from tqdm.contrib.concurrent import process_map + +from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave + + +def remove_silence(filepath): + output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, "")) + # ignore if the file exists + if os.path.exists(output_path) and not args.force: + return + + # create all directory structure + pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True) + # load wave + audio, sample_rate = read_wave(filepath) + + # get speech segments + segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=args.aggressiveness) + + segments = list(segments) + num_segments = len(segments) + flag = False + # create the output wave + if num_segments != 0: + for i, segment in reversed(list(enumerate(segments))): + if i >= 1: + if not flag: + concat_segment = segment + flag = True + else: + concat_segment = segment + concat_segment + else: + if flag: + segment = segment + concat_segment + # print("Saving: ", output_path) + write_wave(output_path, segment, sample_rate) + return + else: + print("> Just Copying the file to:", output_path) + # if fail to remove silence just write the file + write_wave(output_path, audio, sample_rate) + return + + +def preprocess_audios(): + files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True)) + print("> Number of files: ", len(files)) + if not args.force: + print("> Ignoring files that already exist in the output directory.") + + if files: + # create threads + num_threads = multiprocessing.cpu_count() + process_map(remove_silence, files, max_workers=num_threads, chunksize=15) + else: + print("> No files Found !") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2" + ) + parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir") + parser.add_argument( + "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir" + ) + 
parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files") + parser.add_argument( + "-g", + "--glob", + type=str, + default="**/*.wav", + help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav", + ) + parser.add_argument( + "-a", + "--aggressiveness", + type=int, + default=2, + help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.", + ) + args = parser.parse_args() + preprocess_audios() diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index bf7de798..509b3da6 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -152,12 +152,19 @@ If you don't specify any models, then it uses LJSpeech based English model. # args for multi-speaker synthesis parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None) parser.add_argument( "--speaker_idx", type=str, help="Target speaker ID for a multi-speaker TTS model.", default=None, ) + parser.add_argument( + "--language_idx", + type=str, + help="Target language ID for a multi-lingual TTS model.", + default=None, + ) parser.add_argument( "--speaker_wav", nargs="+", @@ -173,6 +180,14 @@ If you don't specify any models, then it uses LJSpeech based English model. const=True, default=False, ) + parser.add_argument( + "--list_language_idxs", + help="List available language ids for the defined multi-lingual model.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) # aux args parser.add_argument( "--save_spectogram", @@ -184,7 +199,7 @@ If you don't specify any models, then it uses LJSpeech based English model. args = parser.parse_args() # print the description if either text or list_models is not set - if args.text is None and not args.list_models and not args.list_speaker_idxs: + if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs: parser.parse_args(["-h"]) # load model manager @@ -194,6 +209,7 @@ If you don't specify any models, then it uses LJSpeech based English model. model_path = None config_path = None speakers_file_path = None + language_ids_file_path = None vocoder_path = None vocoder_config_path = None encoder_path = None @@ -217,6 +233,7 @@ If you don't specify any models, then it uses LJSpeech based English model. model_path = args.model_path config_path = args.config_path speakers_file_path = args.speakers_file_path + language_ids_file_path = args.language_ids_file_path if args.vocoder_path is not None: vocoder_path = args.vocoder_path @@ -231,6 +248,7 @@ If you don't specify any models, then it uses LJSpeech based English model. model_path, config_path, speakers_file_path, + language_ids_file_path, vocoder_path, vocoder_config_path, encoder_path, @@ -246,6 +264,14 @@ If you don't specify any models, then it uses LJSpeech based English model. print(synthesizer.tts_model.speaker_manager.speaker_ids) return + # query langauge ids of a multi-lingual model. + if args.list_language_idxs: + print( + " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." + ) + print(synthesizer.tts_model.language_manager.language_id_mapping) + return + # check the arguments against a multi-speaker model. 
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): print( @@ -258,7 +284,7 @@ If you don't specify any models, then it uses LJSpeech based English model. print(" > Text: {}".format(args.text)) # kick it - wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style) + wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav) # save the results print(" > Saving output to {}".format(args.out_path)) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index ad6d95f7..8c364300 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss -from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model +from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model from TTS.speaker_encoder.utils.training import init_training from TTS.speaker_encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples @@ -151,7 +151,7 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_eval ap = AudioProcessor(**c.audio) - model = setup_model(c) + model = setup_speaker_encoder_model(c) optimizer = RAdam(model.parameters(), lr=c.lr) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index e28e9dec..3360a940 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,9 +1,10 @@ import os -from TTS.config import load_config, register_config +from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config from TTS.trainer import Trainer, TrainingArgs from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model +from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor @@ -45,15 +46,32 @@ def main(): ap = AudioProcessor(**config.audio) # init speaker manager - if config.use_speaker_embedding: + if check_config_and_model_args(config, "use_speaker_embedding", True): speaker_manager = SpeakerManager(data_items=train_samples + eval_samples) - elif config.use_d_vector_file: - speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) + if hasattr(config, "model_args"): + config.model_args.num_speakers = speaker_manager.num_speakers + else: + config.num_speakers = speaker_manager.num_speakers + elif check_config_and_model_args(config, "use_d_vector_file", True): + speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file")) + if hasattr(config, "model_args"): + config.model_args.num_speakers = speaker_manager.num_speakers + else: + config.num_speakers = speaker_manager.num_speakers else: speaker_manager = None + if hasattr(config, "use_language_embedding") and config.use_language_embedding: + language_manager = LanguageManager(config=config) + if hasattr(config, "model_args"): + config.model_args.num_languages = language_manager.num_languages + else: + config.num_languages = language_manager.num_languages + else: + language_manager = None + # init the model from config - model = setup_model(config, speaker_manager) + model = setup_model(config, speaker_manager, language_manager) # init the trainer and 🚀 trainer = Trainer( diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py 
index f626163f..5c905295 100644
--- a/TTS/config/__init__.py
+++ b/TTS/config/__init__.py
@@ -95,3 +95,38 @@ def load_config(config_path: str) -> None:
     config = config_class()
     config.from_dict(config_dict)
     return config
+
+
+def check_config_and_model_args(config, arg_name, value):
+    """Check the given argument in `config.model_args` if it exists, otherwise in `config`, for
+    the given value.
+
+    Return False if the argument does not exist in `config.model_args` or `config`.
+    This is to patch up the compatibility between models with and without `model_args`.
+
+    TODO: Remove this in the future with a unified approach.
+    """
+    if hasattr(config, "model_args"):
+        if arg_name in config.model_args:
+            return config.model_args[arg_name] == value
+    if hasattr(config, arg_name):
+        return config[arg_name] == value
+    return False
+
+
+def get_from_config_or_model_args(config, arg_name):
+    """Get the given argument from `config.model_args` if it exists, otherwise from `config`."""
+    if hasattr(config, "model_args"):
+        if arg_name in config.model_args:
+            return config.model_args[arg_name]
+    return config[arg_name]
+
+
+def get_from_config_or_model_args_with_default(config, arg_name, def_val):
+    """Get the given argument from `config.model_args` if it exists, otherwise from `config`, or return `def_val`."""
+    if hasattr(config, "model_args"):
+        if arg_name in config.model_args:
+            return config.model_args[arg_name]
+    if hasattr(config, arg_name):
+        return config[arg_name]
+    return def_val
diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py
index d91bf2b6..9e9d4692 100644
--- a/TTS/config/shared_configs.py
+++ b/TTS/config/shared_configs.py
@@ -60,6 +60,12 @@ class BaseAudioConfig(Coqpit):
         trim_db (int):
             Silence threshold used for silence trimming. Defaults to 45.

+        do_rms_norm (bool, optional):
+            enable/disable RMS volume normalization when loading an audio file. Defaults to False.
+
+        db_level (float, optional):
+            dB level used for rms normalization. The range is -99 to 0. Defaults to None.
+
         power (float):
             Exponent used for expanding spectrogram levels before running Griffin Lim. It helps to reduce the
             artifacts in the synthesized voice. Defaults to 1.5.
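The helpers added to `TTS/config/__init__.py` above bridge configs that keep multi-speaker/multi-lingual options at the top level and configs that nest them under `model_args`. Below is a minimal sketch of how they resolve an argument, assuming the package from this diff is installed; `"hypothetical_arg"` is a made-up name used only to demonstrate the default fallback.

```python
from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_speaker_embedding = True

# Found in `config.model_args`, so the nested value wins over the top-level one.
print(check_config_and_model_args(config, "use_speaker_embedding", True))  # True

# Not present in `model_args` or `config`, so the provided default is returned.
print(get_from_config_or_model_args_with_default(config, "hypothetical_arg", 512))  # 512
```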
@@ -116,6 +122,9 @@ class BaseAudioConfig(Coqpit): # silence trimming do_trim_silence: bool = True trim_db: int = 45 + # rms volume normalization + do_rms_norm: bool = False + db_level: float = None # griffin-lim params power: float = 1.5 griffin_lim_iters: int = 60 @@ -198,7 +207,8 @@ class BaseDatasetConfig(Coqpit): name: str = "" path: str = "" meta_file_train: str = "" - ununsed_speakers: List[str] = None + ignored_speakers: List[str] = None + language: str = "" meta_file_val: str = "" meta_file_attn_mask: str = "" @@ -335,6 +345,8 @@ class BaseTrainingConfig(Coqpit): num_loader_workers: int = 0 num_eval_loader_workers: int = 0 use_noise_augment: bool = False + use_language_weighted_sampler: bool = False + # paths output_path: str = None # distributed diff --git a/TTS/server/server.py b/TTS/server/server.py index c6d67141..f2512582 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -100,7 +100,15 @@ if args.vocoder_path is not None: # load models synthesizer = Synthesizer( - model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=speakers_file_path, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint="", + encoder_config="", + use_cuda=args.use_cuda, ) use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1 @@ -165,7 +173,7 @@ def tts(): style_wav = style_wav_uri_to_dict(style_wav) print(" > Model input: {}".format(text)) - wavs = synthesizer.tts(text, speaker_idx=speaker_idx, style_wav=style_wav) + wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav) out = io.BytesIO() synthesizer.save_wav(wavs, out) return send_file(out, mimetype="audio/wav") diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py index 6b2b0dd4..5b0fee22 100644 --- a/TTS/speaker_encoder/dataset.py +++ b/TTS/speaker_encoder/dataset.py @@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset): feats = torch.stack(feats) labels = torch.stack(labels) - return feats.transpose(1, 2), labels + return feats, labels diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py index de5bb007..7ac08514 100644 --- a/TTS/speaker_encoder/models/lstm.py +++ b/TTS/speaker_encoder/models/lstm.py @@ -1,7 +1,9 @@ import numpy as np import torch +import torchaudio from torch import nn +from TTS.speaker_encoder.models.resnet import PreEmphasis from TTS.utils.io import load_fsspec @@ -33,9 +35,21 @@ class LSTMWithoutProjection(nn.Module): class LSTMSpeakerEncoder(nn.Module): - def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): + def __init__( + self, + input_dim, + proj_dim=256, + lstm_dim=768, + num_lstm_layers=3, + use_lstm_with_projection=True, + use_torch_spec=False, + audio_config=None, + ): super().__init__() self.use_lstm_with_projection = use_lstm_with_projection + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + layers = [] # choise LSTM layer if use_lstm_with_projection: @@ -46,6 +60,38 @@ class LSTMSpeakerEncoder(nn.Module): else: self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + # TorchSTFT( + # 
n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + else: + self.torch_spec = None + self._init_layers() def _init_layers(self): @@ -55,22 +101,33 @@ class LSTMSpeakerEncoder(nn.Module): elif "weight" in name: nn.init.xavier_normal_(param) - def forward(self, x): - # TODO: implement state passing for lstms + def forward(self, x, l2_norm=True): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.use_torch_spec: + x.squeeze_(1) + x = self.torch_spec(x) + x = self.instancenorm(x).transpose(1, 2) d = self.layers(x) if self.use_lstm_with_projection: - d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) - else: + d = d[:, -1] + if l2_norm: d = torch.nn.functional.normalize(d, p=2, dim=1) return d @torch.no_grad() - def inference(self, x): - d = self.layers.forward(x) - if self.use_lstm_with_projection: - d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) - else: - d = torch.nn.functional.normalize(d, p=2, dim=1) + def inference(self, x, l2_norm=True): + d = self.forward(x, l2_norm=l2_norm) return d def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index fcc850d7..643449c8 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -1,10 +1,25 @@ import numpy as np import torch +import torchaudio from torch import nn +# from TTS.utils.audio import TorchSTFT from TTS.utils.io import load_fsspec +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + class SELayer(nn.Module): def __init__(self, channel, reduction=8): super(SELayer, self).__init__() @@ -70,12 +85,17 @@ class ResNetSpeakerEncoder(nn.Module): num_filters=[32, 64, 128, 256], encoder_type="ASP", log_input=False, + use_torch_spec=False, + audio_config=None, ): super(ResNetSpeakerEncoder, self).__init__() self.encoder_type = encoder_type self.input_dim = input_dim self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU(inplace=True) self.bn1 = nn.BatchNorm2d(num_filters[0]) @@ -88,6 +108,36 @@ class ResNetSpeakerEncoder(nn.Module): self.instancenorm = 
nn.InstanceNorm1d(input_dim) + if self.use_torch_spec: + self.torch_spec = torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + # TorchSTFT( + # n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + else: + self.torch_spec = None + outmap_size = int(self.input_dim / 8) self.attention = nn.Sequential( @@ -140,9 +190,23 @@ class ResNetSpeakerEncoder(nn.Module): return out def forward(self, x, l2_norm=False): - x = x.transpose(1, 2) + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + if self.log_input: x = (x + 1e-6).log() x = self.instancenorm(x).unsqueeze(1) @@ -175,11 +239,19 @@ class ResNetSpeakerEncoder(nn.Module): return x @torch.no_grad() - def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): + def inference(self, x, l2_norm=False): + return self.forward(x, l2_norm) + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True): """ Generate embeddings for a batch of utterances x: 1xTxD """ + # map to the waveform size + if self.use_torch_spec: + num_frames = num_frames * self.audio_config["hop_length"] + max_len = x.shape[1] if max_len < num_frames: @@ -195,11 +267,10 @@ class ResNetSpeakerEncoder(nn.Module): frames_batch.append(frames) frames_batch = torch.cat(frames_batch, dim=0) - embeddings = self.forward(frames_batch, l2_norm=True) + embeddings = self.inference(frames_batch, l2_norm=l2_norm) if return_mean: embeddings = torch.mean(embeddings, dim=0, keepdim=True) - return embeddings def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index 1981fbe9..b8aa4093 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -170,16 +170,24 @@ def to_camel(text): return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) -def setup_model(c): - if c.model_params["model_name"].lower() == "lstm": +def setup_speaker_encoder_model(config: "Coqpit"): + if config.model_params["model_name"].lower() == "lstm": model = LSTMSpeakerEncoder( - c.model_params["input_dim"], - c.model_params["proj_dim"], - c.model_params["lstm_dim"], - c.model_params["num_lstm_layers"], + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + 
use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + elif config.model_params["model_name"].lower() == "resnet": + model = ResNetSpeakerEncoder( + input_dim=config.model_params["input_dim"], + proj_dim=config.model_params["proj_dim"], + log_input=config.model_params.get("log_input", False), + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, ) - elif c.model_params["model_name"].lower() == "resnet": - model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"]) return model diff --git a/TTS/trainer.py b/TTS/trainer.py index 2a2cfc46..7bffb386 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -202,7 +202,7 @@ class Trainer: os.makedirs(output_path, exist_ok=True) # copy training assets to the output folder - copy_model_files(config, output_path, new_fields=None) + copy_model_files(config, output_path) # init class members self.args = args @@ -439,7 +439,7 @@ class Trainer: if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]: print(" > Restoring Scaler...") scaler = _restore_list_objs(checkpoint["scaler"], scaler) - except (KeyError, RuntimeError): + except (KeyError, RuntimeError, ValueError): print(" > Partial model initialization...") model_dict = model.state_dict() model_dict = set_init_dict(model_dict, checkpoint["model"], config) diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index d490e6e6..36c948af 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -82,8 +82,14 @@ class VitsConfig(BaseTTSConfig): add_blank (bool): If true, a blank token is added in between every character. Defaults to `True`. - test_sentences (List[str]): - List of sentences to be used for testing. + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. Note: Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. @@ -117,6 +123,7 @@ class VitsConfig(BaseTTSConfig): feat_loss_alpha: float = 1.0 mel_loss_alpha: float = 45.0 dur_loss_alpha: float = 1.0 + speaker_encoder_loss_alpha: float = 1.0 # data loader params return_wav: bool = True @@ -130,13 +137,13 @@ class VitsConfig(BaseTTSConfig): add_blank: bool = True # testing - test_sentences: List[str] = field( + test_sentences: List[List] = field( default_factory=lambda: [ - "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - "Be a voice, not an echo.", - "I'm sorry Dave. I'm afraid I can't do that.", - "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963.", + ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."], + ["Be a voice, not an echo."], + ["I'm sorry Dave. I'm afraid I can't do that."], + ["This cake is great. 
It's so delicious and moist."], + ["Prior to November 22, 1963."], ] ) @@ -146,29 +153,15 @@ class VitsConfig(BaseTTSConfig): use_speaker_embedding: bool = False speakers_file: str = None speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False # use d-vectors use_d_vector_file: bool = False - d_vector_file: str = False + d_vector_file: str = None d_vector_dim: int = None def __post_init__(self): - # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. - if self.num_speakers > 0: - self.model_args.num_speakers = self.num_speakers - - # speaker embedding settings - if self.use_speaker_embedding: - self.model_args.use_speaker_embedding = True - if self.speakers_file: - self.model_args.speakers_file = self.speakers_file - if self.speaker_embedding_channels: - self.model_args.speaker_embedding_channels = self.speaker_embedding_channels - - # d-vector settings - if self.use_d_vector_file: - self.model_args.use_d_vector_file = True - if self.d_vector_dim is not None and self.d_vector_dim > 0: - self.model_args.d_vector_dim = self.d_vector_dim - if self.d_vector_file: - self.model_args.d_vector_file = self.d_vector_file + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 4fae974f..40eed7e3 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -67,16 +67,22 @@ def load_tts_samples( root_path = dataset["path"] meta_file_train = dataset["meta_file_train"] meta_file_val = dataset["meta_file_val"] + ignored_speakers = dataset["ignored_speakers"] + language = dataset["language"] + # setup the right data processor if formatter is None: formatter = _get_formatter_by_name(name) # load train set - meta_data_train = formatter(root_path, meta_file_train) + meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) + meta_data_train = [[*item, language] for item in meta_data_train] + print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") # load evaluation split if set if eval_split: if meta_file_val: - meta_data_eval = formatter(root_path, meta_file_val) + meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) + meta_data_eval = [[*item, language] for item in meta_data_eval] else: meta_data_eval, meta_data_train = split_dataset(meta_data_train) meta_data_eval_all += meta_data_eval diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 04314bab..843cea58 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -37,6 +37,7 @@ class TTSDataset(Dataset): enable_eos_bos: bool = False, speaker_id_mapping: Dict = None, d_vector_mapping: Dict = None, + language_id_mapping: Dict = None, use_noise_augment: bool = False, verbose: bool = False, ): @@ -122,7 +123,9 @@ class TTSDataset(Dataset): self.enable_eos_bos = enable_eos_bos self.speaker_id_mapping = speaker_id_mapping self.d_vector_mapping = d_vector_mapping + self.language_id_mapping = language_id_mapping self.use_noise_augment = use_noise_augment + self.verbose = verbose self.input_seq_computed = False self.rescue_item_idx = 1 @@ -197,10 +200,10 @@ class TTSDataset(Dataset): def load_data(self, idx): item = self.items[idx] - if len(item) == 4: - text, wav_file, speaker_name, attn_file = item + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item else: - text, wav_file, 
speaker_name = item + text, wav_file, speaker_name, language_name = item attn = None raw_text = text @@ -218,7 +221,7 @@ class TTSDataset(Dataset): self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, - self.phoneme_language, + language_name if language_name else self.phoneme_language, self.custom_symbols, self.characters, self.add_blank, @@ -260,6 +263,7 @@ class TTSDataset(Dataset): "attn": attn, "item_idx": self.items[idx][1], "speaker_name": speaker_name, + "language_name": language_name, "wav_file_name": os.path.basename(wav_file), } return sample @@ -269,6 +273,7 @@ class TTSDataset(Dataset): item = args[0] func_args = args[1] text, wav_file, *_ = item + func_args[3] = item[3] phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args) return phonemes @@ -335,7 +340,6 @@ class TTSDataset(Dataset): else: lengths = np.array([len(ins[0]) for ins in self.items]) - # sort items based on the sequence length in ascending order idxs = np.argsort(lengths) new_items = [] ignored = [] @@ -345,10 +349,7 @@ class TTSDataset(Dataset): ignored.append(idx) else: new_items.append(self.items[idx]) - # shuffle batch groups - # create batches with similar length items - # the larger the `batch_group_size`, the higher the length variety in a batch. if self.batch_group_size > 0: for i in range(len(new_items) // self.batch_group_size): offset = i * self.batch_group_size @@ -356,14 +357,8 @@ class TTSDataset(Dataset): temp_items = new_items[offset:end_offset] random.shuffle(temp_items) new_items[offset:end_offset] = temp_items - - if len(new_items) == 0: - raise RuntimeError(" [!] No items left after filtering.") - - # update items to the new sorted items self.items = new_items - # logging if self.verbose: print(" | > Max length sequence: {}".format(np.max(lengths))) print(" | > Min length sequence: {}".format(np.min(lengths))) @@ -413,9 +408,14 @@ class TTSDataset(Dataset): # convert list of dicts to dict of lists batch = {k: [dic[k] for dic in batch] for k in batch[0]} + # get language ids from language names + if self.language_id_mapping is not None: + language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]] + else: + language_ids = None # get pre-computed d-vectors if self.d_vector_mapping is not None: - wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing] + wav_files_names = list(batch["wav_file_name"]) d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names] else: d_vectors = None @@ -466,6 +466,9 @@ class TTSDataset(Dataset): if speaker_ids is not None: speaker_ids = torch.LongTensor(speaker_ids) + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + # compute linear spectrogram if self.compute_linear_spec: linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] @@ -528,6 +531,7 @@ class TTSDataset(Dataset): "waveform": wav_padded, "raw_text": batch["raw_text"], "pitch": pitch, + "language_ids": language_ids, } raise TypeError( @@ -542,7 +546,6 @@ class TTSDataset(Dataset): class PitchExtractor: """Pitch Extractor for computing F0 from wav files. - Args: items (List[List]): Dataset samples. verbose (bool): Whether to print the progress. 
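To summarize the language plumbing added above: `load_tts_samples` appends each dataset's `language` to every sample, and `TTSDataset.collate_fn` converts the per-sample language names into integer ids through `language_id_mapping` before handing the batch to the model. A minimal sketch of that lookup follows; the language codes are illustrative, not values taken from the diff.

```python
import torch

# e.g. built by LanguageManager from the `language` field of each BaseDatasetConfig
language_id_mapping = {"en": 0, "pt-br": 1, "fr-fr": 2}

# language names carried by the collated batch (one per sample)
batch_language_names = ["en", "pt-br", "pt-br", "en"]

# same conversion as in `TTSDataset.collate_fn`
language_ids = torch.LongTensor([language_id_mapping[name] for name in batch_language_names])
print(language_ids)  # tensor([0, 1, 1, 0]) -> passed to the model as `batch["language_ids"]`
```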
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 425eb0cd..1f23f85e 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -12,7 +12,7 @@ from tqdm import tqdm ######################## -def tweb(root_path, meta_file): +def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalize TWEB dataset. https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset """ @@ -28,7 +28,7 @@ def tweb(root_path, meta_file): return items -def mozilla(root_path, meta_file): +def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalizes Mozilla meta data files to TTS format""" txt_file = os.path.join(root_path, meta_file) items = [] @@ -43,7 +43,7 @@ def mozilla(root_path, meta_file): return items -def mozilla_de(root_path, meta_file): +def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalizes Mozilla meta data files to TTS format""" txt_file = os.path.join(root_path, meta_file) items = [] @@ -59,7 +59,7 @@ def mozilla_de(root_path, meta_file): return items -def mailabs(root_path, meta_files=None): +def mailabs(root_path, meta_files=None, ignored_speakers=None): """Normalizes M-AI-Labs meta data files to TTS format Args: @@ -68,25 +68,34 @@ def mailabs(root_path, meta_files=None): recursively. Defaults to None """ speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/") - if meta_files is None: + if not meta_files: csv_files = glob(root_path + "/**/metadata.csv", recursive=True) else: csv_files = meta_files + # meta_files = [f.strip() for f in meta_files.split(",")] items = [] for csv_file in csv_files: - txt_file = os.path.join(root_path, csv_file) + if os.path.isfile(csv_file): + txt_file = csv_file + else: + txt_file = os.path.join(root_path, csv_file) + folder = os.path.dirname(txt_file) # determine speaker based on folder structure... speaker_name_match = speaker_regex.search(txt_file) if speaker_name_match is None: continue speaker_name = speaker_name_match.group("speaker_name") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue print(" | > {}".format(csv_file)) with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") - if meta_files is None: + if not meta_files: wav_file = os.path.join(folder, "wavs", cols[0] + ".wav") else: wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") @@ -94,11 +103,12 @@ def mailabs(root_path, meta_files=None): text = cols[1].strip() items.append([text, wav_file, speaker_name]) else: - raise RuntimeError("> File %s does not exist!" % (wav_file)) + # M-AI-Labs have some missing samples, so just print the warning + print("> File %s does not exist!" 
% (wav_file)) return items -def ljspeech(root_path, meta_file): +def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalizes the LJSpeech meta data file to TTS format https://keithito.com/LJ-Speech-Dataset/""" txt_file = os.path.join(root_path, meta_file) @@ -113,7 +123,7 @@ def ljspeech(root_path, meta_file): return items -def ljspeech_test(root_path, meta_file): +def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalizes the LJSpeech meta data file for TTS testing https://keithito.com/LJ-Speech-Dataset/""" txt_file = os.path.join(root_path, meta_file) @@ -127,7 +137,7 @@ def ljspeech_test(root_path, meta_file): return items -def sam_accenture(root_path, meta_file): +def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalizes the sam-accenture meta data file to TTS format https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files""" xml_file = os.path.join(root_path, "voice_over_recordings", meta_file) @@ -144,12 +154,12 @@ def sam_accenture(root_path, meta_file): return items -def ruslan(root_path, meta_file): +def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalizes the RUSLAN meta data file to TTS format https://ruslan-corpus.github.io/""" txt_file = os.path.join(root_path, meta_file) items = [] - speaker_name = "ljspeech" + speaker_name = "ruslan" with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") @@ -159,11 +169,11 @@ def ruslan(root_path, meta_file): return items -def css10(root_path, meta_file): +def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalizes the CSS10 dataset file to TTS format""" txt_file = os.path.join(root_path, meta_file) items = [] - speaker_name = "ljspeech" + speaker_name = "css10" with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") @@ -173,7 +183,7 @@ def css10(root_path, meta_file): return items -def nancy(root_path, meta_file): +def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Normalizes the Nancy meta data file to TTS format""" txt_file = os.path.join(root_path, meta_file) items = [] @@ -187,7 +197,7 @@ def nancy(root_path, meta_file): return items -def common_voice(root_path, meta_file): +def common_voice(root_path, meta_file, ignored_speakers=None): """Normalize the common voice meta data file to TTS format.""" txt_file = os.path.join(root_path, meta_file) items = [] @@ -198,15 +208,19 @@ def common_voice(root_path, meta_file): cols = line.split("\t") text = cols[2] speaker_name = cols[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) items.append([text, wav_file, "MCV_" + speaker_name]) return items -def libri_tts(root_path, meta_files=None): +def libri_tts(root_path, meta_files=None, ignored_speakers=None): """https://ai.google/tools/datasets/libri-tts/""" items = [] - if meta_files is None: + if not meta_files: meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True) else: if isinstance(meta_files, str): @@ -222,13 +236,17 @@ def libri_tts(root_path, meta_files=None): _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}") wav_file = os.path.join(_root_path, file_name + ".wav") text = cols[2] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + 
continue items.append([text, wav_file, "LTTS_" + speaker_name]) for item in items: assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}" return items -def custom_turkish(root_path, meta_file): +def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-argument txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "turkish-female" @@ -247,7 +265,7 @@ def custom_turkish(root_path, meta_file): # ToDo: add the dataset link when the dataset is released publicly -def brspeech(root_path, meta_file): +def brspeech(root_path, meta_file, ignored_speakers=None): """BRSpeech 3.0 beta""" txt_file = os.path.join(root_path, meta_file) items = [] @@ -258,21 +276,25 @@ def brspeech(root_path, meta_file): cols = line.split("|") wav_file = os.path.join(root_path, cols[0]) text = cols[2] - speaker_name = cols[3] - items.append([text, wav_file, speaker_name]) + speaker_id = cols[3] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append([text, wav_file, speaker_id]) return items -def vctk(root_path, meta_files=None, wavs_path="wav48"): +def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" - test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for meta_file in meta_files: _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) file_id = txt_file.split(".")[0] - if isinstance(test_speakers, list): # if is list ignore this speakers ids - if speaker_id in test_speakers: + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: continue with open(meta_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] @@ -282,15 +304,16 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"): return items -def vctk_slim(root_path, meta_files=None, wavs_path="wav48"): +def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): # pylint: disable=unused-argument """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" items = [] txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for text_file in txt_files: _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep) file_id = txt_file.split(".")[0] - if isinstance(meta_files, list): # if is list ignore this speakers ids - if speaker_id in meta_files: + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: continue wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") items.append([None, wav_file, "VCTK_" + speaker_id]) @@ -298,7 +321,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"): return items -def mls(root_path, meta_files=None): +def mls(root_path, meta_files=None, ignored_speakers=None): """http://www.openslr.org/94/""" items = [] with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta: @@ -307,19 +330,23 @@ def mls(root_path, meta_files=None): text = text[:-1] speaker, book, *_ = file.split("_") wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker in ignored_speakers: + continue items.append([text, wav_file, "MLS_" + speaker]) return items # ======================================== VOX CELEB 
=========================================== -def voxceleb2(root_path, meta_file=None): +def voxceleb2(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument """ :param meta_file Used only for consistency with load_tts_samples api """ return _voxcel_x(root_path, meta_file, voxcel_idx="2") -def voxceleb1(root_path, meta_file=None): +def voxceleb1(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument """ :param meta_file Used only for consistency with load_tts_samples api """ @@ -361,7 +388,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx): return [x.strip().split("|") for x in f.readlines()] -def baker(root_path: str, meta_file: str) -> List[List[str]]: +def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument """Normalizes the Baker meta data file to TTS format Args: @@ -381,7 +408,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]: return items -def kokoro(root_path, meta_file): +def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset""" txt_file = os.path.join(root_path, meta_file) items = [] diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py index 2c0303be..e766ed6a 100644 --- a/TTS/tts/layers/glow_tts/duration_predictor.py +++ b/TTS/tts/layers/glow_tts/duration_predictor.py @@ -18,8 +18,13 @@ class DurationPredictor(nn.Module): dropout_p (float): Dropout rate used after each conv layer. """ - def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None): + def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None): super().__init__() + + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + # class arguments self.in_channels = in_channels self.filter_channels = hidden_channels @@ -36,7 +41,10 @@ class DurationPredictor(nn.Module): if cond_channels is not None and cond_channels != 0: self.cond = nn.Conv1d(cond_channels, in_channels, 1) - def forward(self, x, x_mask, g=None): + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1) + + def forward(self, x, x_mask, g=None, lang_emb=None): """ Shapes: - x: :math:`[B, C, T]` @@ -45,6 +53,10 @@ class DurationPredictor(nn.Module): """ if g is not None: x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + x = self.conv_1(x * x_mask) x = torch.relu(x) x = self.norm_1(x) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 0ea342e8..7de45041 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -532,6 +532,7 @@ class VitsGeneratorLoss(nn.Module): self.feat_loss_alpha = c.feat_loss_alpha self.dur_loss_alpha = c.dur_loss_alpha self.mel_loss_alpha = c.mel_loss_alpha + self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha self.stft = TorchSTFT( c.audio.fft_size, c.audio.hop_length, @@ -585,6 +586,11 @@ class VitsGeneratorLoss(nn.Module): l = kl / torch.sum(z_mask) return l + @staticmethod + def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): + l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() + return l + def forward( self, waveform, @@ -598,6 +604,9 @@ class VitsGeneratorLoss(nn.Module): feats_disc_fake, feats_disc_real, loss_duration, + use_speaker_encoder_as_loss=False, + 
gt_spk_emb=None, + syn_spk_emb=None, ): """ Shapes: @@ -618,13 +627,20 @@ class VitsGeneratorLoss(nn.Module): # compute mel spectrograms from the waveforms mel = self.stft(waveform) mel_hat = self.stft(waveform_hat) + # compute losses + loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha - loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration + + if use_speaker_encoder_as_loss: + loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha + loss += loss_se + return_dict["loss_spk_encoder"] = loss_se + # pass losses to the dict return_dict["loss_gen"] = loss_gen return_dict["loss_kl"] = loss_kl diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index cfc8b6ac..ef426ace 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -37,6 +37,7 @@ class TextEncoder(nn.Module): num_layers: int, kernel_size: int, dropout_p: float, + language_emb_dim: int = None, ): """Text Encoder for VITS model. @@ -55,8 +56,12 @@ class TextEncoder(nn.Module): self.hidden_channels = hidden_channels self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + if language_emb_dim: + hidden_channels += language_emb_dim + self.encoder = RelativePositionTransformer( in_channels=hidden_channels, out_channels=hidden_channels, @@ -72,13 +77,18 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths): + def forward(self, x, x_lengths, lang_emb=None): """ Shapes: - x: :math:`[B, T]` - x_length: :math:`[B]` """ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + + # concat the lang emb in embedding chars + if lang_emb is not None: + x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + x = torch.transpose(x, 1, -1) # [b, h, t] x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py index 91e53da3..120d0944 100644 --- a/TTS/tts/layers/vits/stochastic_duration_predictor.py +++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py @@ -178,10 +178,21 @@ class StochasticDurationPredictor(nn.Module): """ def __init__( - self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0 + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + dropout_p: float, + num_flows=4, + cond_channels=0, + language_emb_dim=0, ): super().__init__() + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + # condition encoder text self.pre = nn.Conv1d(in_channels, hidden_channels, 1) self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) @@ -205,7 +216,10 @@ class StochasticDurationPredictor(nn.Module): if cond_channels != 0 and cond_channels is not None: self.cond = nn.Conv1d(cond_channels, hidden_channels, 1) - def forward(self, x, x_mask, dr=None, g=None, 
reverse=False, noise_scale=1.0): + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1) + + def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0): """ Shapes: - x: :math:`[B, C, T]` @@ -217,6 +231,10 @@ class StochasticDurationPredictor(nn.Module): x = self.pre(x) if g is not None: x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + x = self.convs(x, x_mask) x = self.proj(x) * x_mask diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 780f22cd..4cc8b658 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -2,7 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols from TTS.utils.generic_utils import find_module -def setup_model(config, speaker_manager: "SpeakerManager" = None): +def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None): print(" > Using model: {}".format(config.model)) # fetch the right model implementation. if "base_model" in config and config["base_model"] is not None: @@ -31,7 +31,10 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None): config.model_params.num_chars = num_chars if "model_args" in config: config.model_args.num_chars = num_chars - model = MyModel(config, speaker_manager=speaker_manager) + if config.model.lower() in ["vits"]: # If model supports multiple languages + model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager) + else: + model = MyModel(config, speaker_manager=speaker_manager) return model diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 854526de..e52cd765 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -12,7 +12,8 @@ from torch.utils.data.distributed import DistributedSampler from TTS.model import BaseModel from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager +from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text import make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -73,9 +74,18 @@ class BaseTTS(BaseModel): def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: return get_speaker_manager(config, restore_path, data, out_path) - def init_multispeaker(self, config: Coqpit): - """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding - vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension. + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining + `in_channels` size of the connected layers. + + This implementation yields 3 possible outcomes: + + 1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing. + 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. + 3. 
If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of + `config.d_vector_dim` or 512. + + You can override this function for new models. Args: config (Coqpit): Model configuration. @@ -97,6 +107,57 @@ class BaseTTS(BaseModel): self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) + def get_aux_input(self, **kwargs) -> Dict: + """Prepare and return `aux_input` used by `forward()`""" + return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} + + def get_aux_input_from_test_setences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_d_vector() + else: + d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_speaker_id() + else: + speaker_id = self.speaker_manager.speaker_ids[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.language_id_mapping[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } + def format_batch(self, batch: Dict) -> Dict: """Generic batch formatting for `TTSDataset`. 
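For reference, the `test_sentences` entries consumed by the new `get_aux_input_from_test_setences()` helper above are plain lists whose optional tail fields are speaker name, style wav and language name. A minimal standalone sketch of that unpacking; the sentences, speaker name and language code below are made-up placeholders, not values shipped with this patch:

from typing import Optional, Tuple


def parse_sentence_info(sentence_info) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    # mirrors the unpacking order used above: [text, speaker_name, style_wav, language_name]
    text, speaker_name, style_wav, language_name = None, None, None, None
    if isinstance(sentence_info, list):
        text, speaker_name, style_wav, language_name = (sentence_info + [None] * 4)[:4]
    else:
        text = sentence_info
    return text, speaker_name, style_wav, language_name


# hypothetical multilingual test sentences (format only)
test_sentences = [
    ["It took me quite a long time to develop a voice."],
    ["Bonjour, comment allez-vous ?", "speaker_01", None, "fr-fr"],
]
for info in test_sentences:
    print(parse_sentence_info(info))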
@@ -122,6 +183,7 @@ class BaseTTS(BaseModel): attn_mask = batch["attns"] waveform = batch["waveform"] pitch = batch["pitch"] + language_ids = batch["language_ids"] max_text_length = torch.max(text_lengths.float()) max_spec_length = torch.max(mel_lengths.float()) @@ -169,6 +231,7 @@ class BaseTTS(BaseModel): "item_idx": item_idx, "waveform": waveform, "pitch": pitch, + "language_ids": language_ids, } def get_data_loader( @@ -188,8 +251,15 @@ class BaseTTS(BaseModel): # setup multi-speaker attributes if hasattr(self, "speaker_manager") and self.speaker_manager is not None: - speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None - d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None + if hasattr(config, "model_args"): + speaker_id_mapping = ( + self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None + ) + d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None else: speaker_id_mapping = None d_vector_mapping = None @@ -199,7 +269,14 @@ class BaseTTS(BaseModel): if hasattr(self, "make_symbols"): custom_symbols = self.make_symbols(self.config) - # init dataset + if hasattr(self, "language_manager"): + language_id_mapping = ( + self.language_manager.language_id_mapping if self.args.use_language_embedding else None + ) + else: + language_id_mapping = None + + # init dataloader dataset = TTSDataset( outputs_per_step=config.r if "r" in config else 1, text_cleaner=config.text_cleaner, @@ -222,7 +299,8 @@ class BaseTTS(BaseModel): use_noise_augment=False if is_eval else config.use_noise_augment, verbose=verbose, speaker_id_mapping=speaker_id_mapping, - d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + d_vector_mapping=d_vector_mapping, + language_id_mapping=language_id_mapping, ) # pre-compute phonemes @@ -268,7 +346,22 @@ class BaseTTS(BaseModel): # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None - # init dataloader + # Weighted samplers + assert not ( + num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False) + ), "language_weighted_sampler is not supported with DistributedSampler" + assert not ( + num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False) + ), "speaker_weighted_sampler is not supported with DistributedSampler" + + if sampler is None: + if getattr(config, "use_language_weighted_sampler", False): + print(" > Using Language weighted sampler") + sampler = get_language_weighted_sampler(dataset.items) + elif getattr(config, "use_speaker_weighted_sampler", False): + print(" > Using Speaker weighted sampler") + sampler = get_speaker_weighted_sampler(dataset.items) + loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, @@ -340,8 +433,7 @@ class BaseTTS(BaseModel): return test_figures, test_audios def on_init_start(self, trainer): - """Save the speaker.json at the beginning of the training. And update the config.json with the - speakers.json file path.""" + """Save the speaker.json and language_ids.json at the beginning of the training.
Also update both paths.""" if self.speaker_manager is not None: output_path = os.path.join(trainer.output_path, "speakers.json") self.speaker_manager.save_speaker_ids_to_file(output_path) @@ -352,3 +444,13 @@ class BaseTTS(BaseModel): trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) print(f" > `speakers.json` is saved to {output_path}.") print(" > `speakers_file` is updated in the config.json.") + + if hasattr(self, "language_manager") and self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_language_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `language_ids.json` is saved to {output_path}.") + print(" > `language_ids_file` is updated in the config.json.") diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index bc459b7f..8b09fdf9 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,13 +1,15 @@ import math -import random from dataclasses import dataclass, field from itertools import chain from typing import Dict, List, Tuple import torch + +# import torchaudio from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.vits.discriminator import VitsDiscriminator @@ -15,6 +17,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment @@ -138,11 +141,50 @@ class VitsArgs(Coqpit): use_d_vector_file (bool): Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + d_vector_dim (int): Number of d-vector channels. Defaults to 0. detach_dp_input (bool): Detach duration predictor's input from the network for stopping the gradients. Defaults to True. + + use_language_embedding (bool): + Enable/Disable language embedding for multilingual models. Defaults to False. + + embedded_language_dim (int): + Number of language embedding channels. Defaults to 4. + + num_languages (int): + Number of languages for the language embedding layer. Defaults to 0. + + language_ids_file (str): + Path to the language mapping file for the Language Manager. Defaults to None. + + use_speaker_encoder_as_loss (bool): + Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. + + speaker_encoder_config_path (str): + Path to the file speaker encoder config file, to use for SCL. Defaults to "". + + speaker_encoder_model_path (str): + Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "". + + freeze_encoder (bool): + Freeze the encoder weigths during training. Defaults to False. + + freeze_DP (bool): + Freeze the duration predictor weigths during training. Defaults to False. 
+ + freeze_PE (bool): + Freeze the posterior encoder weights during training. Defaults to False. + + freeze_flow_decoder (bool): + Freeze the flow decoder weights during training. Defaults to False. + + freeze_waveform_decoder (bool): + Freeze the waveform decoder weights during training. Defaults to False. """ num_chars: int = 100 @@ -179,11 +221,23 @@ class VitsArgs(Coqpit): use_speaker_embedding: bool = False num_speakers: int = 0 speakers_file: str = None + d_vector_file: str = None speaker_embedding_channels: int = 256 use_d_vector_file: bool = False - d_vector_file: str = None d_vector_dim: int = 0 detach_dp_input: bool = True + use_language_embedding: bool = False + embedded_language_dim: int = 4 + num_languages: int = 0 + language_ids_file: str = None + use_speaker_encoder_as_loss: bool = False + speaker_encoder_config_path: str = "" + speaker_encoder_model_path: str = "" + freeze_encoder: bool = False + freeze_DP: bool = False + freeze_PE: bool = False + freeze_flow_decoder: bool = False + freeze_waveform_decoder: bool = False class Vits(BaseTTS): @@ -216,13 +270,18 @@ class Vits(BaseTTS): # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + def __init__( + self, + config: Coqpit, + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): super().__init__(config) self.END2END = True - self.speaker_manager = speaker_manager + self.language_manager = language_manager if config.__class__.__name__ == "VitsConfig": # loading from VitsConfig if "num_chars" not in config: @@ -242,6 +301,7 @@ class Vits(BaseTTS): self.args = args self.init_multispeaker(config) + self.init_multilingual(config) self.length_scale = args.length_scale self.noise_scale = args.noise_scale @@ -260,6 +320,7 @@ class Vits(BaseTTS): args.num_layers_text_encoder, args.kernel_size_text_encoder, args.dropout_p_text_encoder, + language_emb_dim=self.embedded_language_dim, ) self.posterior_encoder = PosteriorEncoder( @@ -289,10 +350,16 @@ class Vits(BaseTTS): args.dropout_p_duration_predictor, 4, cond_channels=self.embedded_speaker_dim, + language_emb_dim=self.embedded_language_dim, ) else: self.duration_predictor = DurationPredictor( - args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim + args.hidden_channels, + 256, + 3, + args.dropout_p_duration_predictor, + cond_channels=self.embedded_speaker_dim, + language_emb_dim=self.embedded_language_dim, ) self.waveform_decoder = HifiganGenerator( @@ -318,54 +385,158 @@ class Vits(BaseTTS): """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer or with external `d_vectors` computed from a speaker encoder model. + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + Args: config (Coqpit): Model configuration. data (List, optional): Dataset items to infer number of speakers. Defaults to None.
""" self.embedded_speaker_dim = 0 - if hasattr(config, "model_args"): - config = config.model_args + self.num_speakers = self.args.num_speakers - self.num_speakers = config.num_speakers + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers - if config.use_speaker_embedding: - self._init_speaker_embedding(config) + if self.args.use_speaker_embedding: + self._init_speaker_embedding() - if config.use_d_vector_file: - self._init_d_vector(config) + if self.args.use_d_vector_file: + self._init_d_vector() - def _init_speaker_embedding(self, config): + # TODO: make this a function + if self.args.use_speaker_encoder_as_loss: + if self.speaker_manager.speaker_encoder is None and ( + not config.speaker_encoder_model_path or not config.speaker_encoder_config_path + ): + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" + ) + + self.speaker_manager.speaker_encoder.eval() + print(" > External Speaker Encoder Loaded !!") + + if ( + hasattr(self.speaker_manager.speaker_encoder, "audio_config") + and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"] + ): + # TODO: change this with torchaudio Resample + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format( + self.config.audio["sample_rate"], + self.speaker_manager.speaker_encoder.audio_config["sample_rate"], + ) + ) + # pylint: disable=W0101,W0105 + """ self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.audio_config["sample_rate"], + new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], + ) + else: + self.audio_transform = None + """ + + def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init - if config.speakers_file is not None: - self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file) - if self.num_speakers > 0: print(" > initialization of speaker-embedding layers.") - self.embedded_speaker_dim = config.speaker_embedding_channels + self.embedded_speaker_dim = self.args.speaker_embedding_channels self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) - def _init_d_vector(self, config): + def _init_d_vector(self): # pylint: disable=attribute-defined-outside-init if hasattr(self, "emb_g"): raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") - self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) - self.embedded_speaker_dim = config.d_vector_dim + self.embedded_speaker_dim = self.args.d_vector_dim + + def init_multilingual(self, config: Coqpit): + """Initialize multilingual modules of a model. + + Args: + config (Coqpit): Model configuration. 
+ """ + if self.args.language_ids_file is not None: + self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + + if self.args.use_language_embedding and self.language_manager: + self.num_languages = self.language_manager.num_languages + self.embedded_language_dim = self.args.embedded_language_dim + self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) + torch.nn.init.xavier_uniform_(self.emb_l.weight) + else: + self.embedded_language_dim = 0 + self.emb_l = None @staticmethod def _set_cond_input(aux_input: Dict): """Set the speaker conditioning input based on the multi-speaker mode.""" - sid, g = None, None + sid, g, lid = None, None, None if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: sid = aux_input["speaker_ids"] if sid.ndim == 0: sid = sid.unsqueeze_(0) if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: - g = aux_input["d_vectors"] - return sid, g + g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) + if g.ndim == 2: + g = g.unsqueeze_(0) + + if "language_ids" in aux_input and aux_input["language_ids"] is not None: + lid = aux_input["language_ids"] + if lid.ndim == 0: + lid = lid.unsqueeze_(0) + + return sid, g, lid def get_aux_input(self, aux_input: Dict): - sid, g = self._set_cond_input(aux_input) - return {"speaker_id": sid, "style_wav": None, "d_vector": g} + sid, g, lid = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_d_vector() + else: + d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_speaker_id() + else: + speaker_id = self.speaker_manager.speaker_ids[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.language_id_mapping[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } def forward( self, @@ -373,7 +544,8 @@ class Vits(BaseTTS): x_lengths: torch.tensor, y: torch.tensor, y_lengths: torch.tensor, - aux_input={"d_vectors": None, "speaker_ids": None}, + waveform: torch.tensor, + aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, ) -> Dict: """Forward pass of the model. @@ -382,7 +554,9 @@ class Vits(BaseTTS): x_lengths (torch.tensor): Batch of input character sequence lengths. 
y (torch.tensor): Batch of input spectrograms. y_lengths (torch.tensor): Batch of input spectrogram lengths. - aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}. + waveform (torch.tensor): Batch of ground truth waveforms per sample. + aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training. + Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}. Returns: Dict: model outputs keyed by the output name. @@ -392,17 +566,24 @@ class Vits(BaseTTS): - x_lengths: :math:`[B]` - y: :math:`[B, C, T_spec]` - y_lengths: :math:`[B]` + - waveform: :math:`[B, T_wav, 1]` - d_vectors: :math:`[B, C, 1]` - speaker_ids: :math:`[B]` + - language_ids: :math:`[B]` """ outputs = {} - sid, g = self._set_cond_input(aux_input) - x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths) - + sid, g, lid = self._set_cond_input(aux_input) # speaker embedding - if self.num_speakers > 1 and sid is not None: + if self.args.use_speaker_embedding and sid is not None: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + # posterior encoder z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) @@ -428,6 +609,7 @@ class Vits(BaseTTS): x_mask, attn_durations, g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, ) loss_duration = loss_duration / torch.sum(x_mask) else: @@ -436,6 +618,7 @@ class Vits(BaseTTS): x.detach() if self.args.detach_dp_input else x, x_mask, g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, ) loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) outputs["loss_duration"] = loss_duration @@ -447,40 +630,73 @@ class Vits(BaseTTS): # select a random feature segment for the waveform decoder z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size) o = self.waveform_decoder(z_slice, g=g) + + wav_seg = segment( + waveform, + slice_ids * self.config.audio.hop_length, + self.args.spec_segment_size * self.config.audio.hop_length, + ) + + if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: + # concate generated and GT waveforms + wavs_batch = torch.cat((wav_seg, o), dim=0) + + # resample audio to speaker encoder sample_rate + # pylint: disable=W0105 + """if self.audio_transform is not None: + wavs_batch = self.audio_transform(wavs_batch)""" + + pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True) + + # split generated and GT speaker embeddings + gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) + else: + gt_spk_emb, syn_spk_emb = None, None + outputs.update( { "model_outputs": o, "alignments": attn.squeeze(1), - "slice_ids": slice_ids, "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "m_q": m_q, "logs_q": logs_q, + "waveform_seg": wav_seg, + "gt_spk_emb": gt_spk_emb, + "syn_spk_emb": syn_spk_emb, } ) return outputs - def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}): """ Shapes: 
- x: :math:`[B, T_seq]` - d_vectors: :math:`[B, C, 1]` - speaker_ids: :math:`[B]` """ - sid, g = self._set_cond_input(aux_input) + sid, g, lid = self._set_cond_input(aux_input) x_lengths = torch.tensor(x.shape[1:2]).to(x.device) - x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths) - - if self.num_speakers > 0 and sid is not None: + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: g = self.emb_g(sid).unsqueeze(-1) + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + if self.args.use_sdp: - logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp) + logw = self.duration_predictor( + x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb + ) else: - logw = self.duration_predictor(x, x_mask, g=g) + logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb) w = torch.exp(logw) * x_mask * self.length_scale w_ceil = torch.ceil(w) @@ -499,12 +715,30 @@ class Vits(BaseTTS): outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p} return outputs - def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): - """TODO: create an end-point for voice conversion""" + def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt): + """Forward pass for voice conversion + + TODO: create an end-point for voice conversion + + Args: + y (Tensor): Reference spectrograms. Tensor of shape [B, T, C] + y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B] + speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,] + speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,] + """ assert self.num_speakers > 0, "num_speakers have to be larger than 0." - g_src = self.emb_g(sid_src).unsqueeze(-1) - g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) - z, _, _, y_mask = self.enc_q(y, y_lengths, g=g_src) + + # speaker embedding + if self.args.use_speaker_embedding and not self.args.use_d_vector_file: + g_src = self.emb_g(speaker_cond_src).unsqueeze(-1) + g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1) + elif self.args.use_speaker_embedding and self.args.use_d_vector_file: + g_src = F.normalize(speaker_cond_src).unsqueeze(-1) + g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1) + else: + raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.") + + z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src) z_p = self.flow(z, y_mask, g=g_src) z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) @@ -525,6 +759,30 @@ class Vits(BaseTTS): if optimizer_idx not in [0, 1]: raise ValueError(" [!] 
Unexpected `optimizer_idx`.") + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if hasattr(self, "emb_l"): + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + if optimizer_idx == 0: text_input = batch["text_input"] text_lengths = batch["text_lengths"] @@ -532,6 +790,7 @@ class Vits(BaseTTS): linear_input = batch["linear_input"] d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] + language_ids = batch["language_ids"] waveform = batch["waveform"] # generator pass @@ -540,31 +799,26 @@ class Vits(BaseTTS): text_lengths, linear_input.transpose(1, 2), mel_lengths, - aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + waveform.transpose(1, 2), + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, ) # cache tensors for the discriminator self.y_disc_cache = None self.wav_seg_disc_cache = None self.y_disc_cache = outputs["model_outputs"] - wav_seg = segment( - waveform.transpose(1, 2), - outputs["slice_ids"] * self.config.audio.hop_length, - self.args.spec_segment_size * self.config.audio.hop_length, - ) - self.wav_seg_disc_cache = wav_seg - outputs["waveform_seg"] = wav_seg + self.wav_seg_disc_cache = outputs["waveform_seg"] # compute discriminator scores and features outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc( - outputs["model_outputs"], wav_seg + outputs["model_outputs"], outputs["waveform_seg"] ) # compute losses with autocast(enabled=False): # use float32 for the criterion loss_dict = criterion[optimizer_idx]( waveform_hat=outputs["model_outputs"].float(), - waveform=wav_seg.float(), + waveform=outputs["waveform_seg"].float(), z_p=outputs["z_p"].float(), logs_q=outputs["logs_q"].float(), m_p=outputs["m_p"].float(), @@ -574,6 +828,9 @@ class Vits(BaseTTS): feats_disc_fake=outputs["feats_disc_fake"], feats_disc_real=outputs["feats_disc_real"], loss_duration=outputs["loss_duration"], + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb=outputs["gt_spk_emb"], + syn_spk_emb=outputs["syn_spk_emb"], ) elif optimizer_idx == 1: @@ -651,32 +908,28 @@ class Vits(BaseTTS): test_audios = {} test_figures = {} test_sentences = self.config.test_sentences - aux_inputs = { - "speaker_id": None - if not self.config.use_speaker_embedding - else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1), - "d_vector": None - if not self.config.use_d_vector_file - else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1), - "style_wav": None, - } - for idx, sen in enumerate(test_sentences): - wav, alignment, _, _ = synthesis( - self, - sen, - self.config, - "cuda" in str(next(self.parameters()).device), - ap, - speaker_id=aux_inputs["speaker_id"], - d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, - use_griffin_lim=True, - do_trim_silence=False, - ).values() - - test_audios["{}-audio".format(idx)] = wav - 
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + for idx, s_info in enumerate(test_sentences): + try: + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + ap, + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + language_name=aux_inputs["language_name"], + enable_eos_bos_chars=self.config.enable_eos_bos_chars, + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + except: # pylint: disable=bare-except + print(" !! Error creating Test Sentence -", idx) return test_figures, test_audios def get_optimizer(self) -> List: @@ -695,8 +948,12 @@ class Vits(BaseTTS): self.waveform_decoder.parameters(), ) # add the speaker embedding layer - if hasattr(self, "emb_g"): + if hasattr(self, "emb_g") and self.args.use_speaker_embedding and not self.args.use_d_vector_file: gen_parameters = chain(gen_parameters, self.emb_g.parameters()) + # add the language embedding layer + if hasattr(self, "emb_l") and self.args.use_language_embedding: + gen_parameters = chain(gen_parameters, self.emb_l.parameters()) + optimizer0 = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters ) @@ -769,6 +1026,10 @@ class Vits(BaseTTS): ): # pylint: disable=unused-argument, redefined-builtin """Load the model checkpoint and setup for training or inference""" state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + # compat band-aid for the pre-trained models to not use the encoder baked into the model + # TODO: consider baking the speaker encoder into the model and call it from there. + # as it is probably easier for model distribution. + state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py new file mode 100644 index 00000000..fc7eec57 --- /dev/null +++ b/TTS/tts/utils/languages.py @@ -0,0 +1,122 @@ +import json +import os +from typing import Dict, List + +import fsspec +import numpy as np +import torch +from coqpit import Coqpit +from torch.utils.data.sampler import WeightedRandomSampler + + +class LanguageManager: + """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information + in a way that can be queried by language. + + Args: + language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by + TTS models. Defaults to "". + config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed. + Defaults to None. 
+ + Examples: + >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path) + >>> language_id_mapper = manager.language_ids + """ + + language_id_mapping: Dict = {} + + def __init__( + self, + language_ids_file_path: str = "", + config: Coqpit = None, + ): + self.language_id_mapping = {} + if language_ids_file_path: + self.set_language_ids_from_file(language_ids_file_path) + + if config: + self.set_language_ids_from_config(config) + + @staticmethod + def _load_json(json_file_path: str) -> Dict: + with fsspec.open(json_file_path, "r") as f: + return json.load(f) + + @staticmethod + def _save_json(json_file_path: str, data: dict) -> None: + with fsspec.open(json_file_path, "w") as f: + json.dump(data, f, indent=4) + + @property + def num_languages(self) -> int: + return len(list(self.language_id_mapping.keys())) + + @property + def language_names(self) -> List: + return list(self.language_id_mapping.keys()) + + @staticmethod + def parse_language_ids_from_config(c: Coqpit) -> Dict: + """Set language id from config. + + Args: + c (Coqpit): Config + + Returns: + Tuple[Dict, int]: Language ID mapping and the number of languages. + """ + languages = set({}) + for dataset in c.datasets: + if "language" in dataset: + languages.add(dataset["language"]) + else: + raise ValueError(f"Dataset {dataset['name']} has no language specified.") + return {name: i for i, name in enumerate(sorted(list(languages)))} + + def set_language_ids_from_config(self, c: Coqpit) -> None: + """Set language IDs from config samples. + + Args: + items (List): Data sampled returned by `load_meta_data()`. + """ + self.language_id_mapping = self.parse_language_ids_from_config(c) + + def set_language_ids_from_file(self, file_path: str) -> None: + """Load language ids from a json file. + + Args: + file_path (str): Path to the target json file. + """ + self.language_id_mapping = self._load_json(file_path) + + def save_language_ids_to_file(self, file_path: str) -> None: + """Save language IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + self._save_json(file_path, self.language_id_mapping) + + +def _set_file_path(path): + """Find the language_ids.json under the given path or the above it. 
+ Intended to band-aid the different paths returned in restored and continued training.""" + path_restore = os.path.join(os.path.dirname(path), "language_ids.json") + path_continue = os.path.join(path, "language_ids.json") + fs = fsspec.get_mapper(path).fs + if fs.exists(path_restore): + return path_restore + if fs.exists(path_continue): + return path_continue + return None + + +def get_language_weighted_sampler(items: list): + language_names = np.array([item[3] for item in items]) + unique_language_names = np.unique(language_names).tolist() + language_ids = [unique_language_names.index(l) for l in language_names] + language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names]) + weight_language = 1.0 / language_count + dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double() + return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight)) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 13696a20..07076d90 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -7,9 +7,10 @@ import fsspec import numpy as np import torch from coqpit import Coqpit +from torch.utils.data.sampler import WeightedRandomSampler from TTS.config import load_config -from TTS.speaker_encoder.utils.generic_utils import setup_model +from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.utils.audio import AudioProcessor @@ -161,8 +162,10 @@ class SpeakerManager: file_path (str): Path to the target json file. """ self.d_vectors = self._load_json(file_path) + speakers = sorted({x["name"] for x in self.d_vectors.values()}) self.speaker_ids = {name: i for i, name in enumerate(speakers)} + self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys()))) def get_d_vector_by_clip(self, clip_idx: str) -> List: @@ -209,6 +212,32 @@ class SpeakerManager: d_vectors = np.stack(d_vectors[:num_samples]).mean(0) return d_vectors + def get_random_speaker_id(self) -> Any: + """Get a random speaker ID. + + Args: + + Returns: + int: Speaker ID. + """ + if self.speaker_ids: + return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]] + + return None + + def get_random_d_vector(self) -> Any: + """Get a random d_vector. + + Args: + + Returns: + np.ndarray: d_vector. + """ + if self.d_vectors: + return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"] + + return None + def get_speakers(self) -> List: return self.speaker_ids @@ -223,18 +252,15 @@ class SpeakerManager: config_path (str): Model config file path. """ self.speaker_encoder_config = load_config(config_path) - self.speaker_encoder = setup_model(self.speaker_encoder_config) + self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config) self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda) self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio) - # normalize the input audio level and trim silences - # self.speaker_encoder_ap.do_sound_norm = True - # self.speaker_encoder_ap.do_trim_silence = True - def compute_d_vector_from_clip(self, wav_file: Union[str, list]) -> list: + def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list: """Compute a d_vector from a given audio file. Args: - wav_file (Union[str, list]): Target file path. + wav_file (Union[str, List[str]]): Target file path. Returns: list: Computed d_vector.
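Both weighted samplers introduced in this patch follow the same recipe: weight every dataset item by the inverse frequency of its class (language or speaker) and hand the weights to `WeightedRandomSampler`. A small sketch with made-up metadata items; the `[text, wav_path, speaker_name, language]` layout is only assumed from the `item[2]` / `item[3]` indexing used above:

from torch.utils.data import DataLoader

from TTS.tts.utils.languages import get_language_weighted_sampler

# hypothetical metadata items; only the speaker (item[2]) and language (item[3]) positions matter
items = [
    ["hello world", "wavs/a.wav", "spk1", "en"],
    ["good morning", "wavs/b.wav", "spk1", "en"],
    ["hallo welt", "wavs/c.wav", "spk2", "de"],
]

sampler = get_language_weighted_sampler(items)
# the single "de" item gets weight 1/1 while each "en" item gets 1/2, so batches drawn
# through the sampler are roughly language-balanced
loader = DataLoader(items, batch_size=2, sampler=sampler)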
@@ -242,12 +268,16 @@ class SpeakerManager: def _compute(wav_file: str): waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate) - spec = self.speaker_encoder_ap.melspectrogram(waveform) - spec = torch.from_numpy(spec.T) + if not self.speaker_encoder_config.model_params.get("use_torch_spec", False): + m_input = self.speaker_encoder_ap.melspectrogram(waveform) + m_input = torch.from_numpy(m_input) + else: + m_input = torch.from_numpy(waveform) + if self.use_cuda: - spec = spec.cuda() - spec = spec.unsqueeze(0) - d_vector = self.speaker_encoder.compute_embedding(spec) + m_input = m_input.cuda() + m_input = m_input.unsqueeze(0) + d_vector = self.speaker_encoder.compute_embedding(m_input) return d_vector if isinstance(wav_file, list): @@ -364,11 +394,14 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file: # new speaker manager with speaker IDs file. speaker_manager.set_speaker_ids_from_file(c.speakers_file) - print( - " > Speaker manager is loaded with {} speakers: {}".format( - speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids) + + if speaker_manager.num_speakers > 0: + print( + " > Speaker manager is loaded with {} speakers: {}".format( + speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids) + ) ) - ) + # save file if path is defined if out_path: out_file_path = os.path.join(out_path, "speakers.json") @@ -378,3 +411,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, else: speaker_manager.save_speaker_ids_to_file(out_file_path) return speaker_manager + + +def get_speaker_weighted_sampler(items: list): + speaker_names = np.array([item[2] for item in items]) + unique_speaker_names = np.unique(speaker_names).tolist() + speaker_ids = [unique_speaker_names.index(l) for l in speaker_names] + speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names]) + weight_speaker = 1.0 / speaker_count + dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double() + return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight)) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 578c26c0..24b747be 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed: import tensorflow as tf -def text_to_seq(text, CONFIG, custom_symbols=None): +def text_to_seq(text, CONFIG, custom_symbols=None, language=None): text_cleaner = [CONFIG.text_cleaner] # text ot phonemes to sequence vector if CONFIG.use_phonemes: @@ -23,7 +23,7 @@ def text_to_seq(text, CONFIG, custom_symbols=None): phoneme_to_sequence( text, text_cleaner, - CONFIG.phoneme_language, + language if language else CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters, add_blank=CONFIG.add_blank, @@ -71,6 +71,7 @@ def run_model_torch( speaker_id: int = None, style_mel: torch.Tensor = None, d_vector: torch.Tensor = None, + language_id: torch.Tensor = None, ) -> Dict: """Run a torch model for inference. It does not support batch inference. 
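The `language` argument added to `text_to_seq()` above lets a single multilingual model phonemize each request with its own language code instead of always falling back to `config.phoneme_language`. A minimal sketch, assuming a phoneme-based model config at the hypothetical path `config.json` and espeak-style language codes:

from TTS.config import load_config
from TTS.tts.utils.synthesis import text_to_seq

config = load_config("config.json")  # placeholder path to a trained multilingual model config

# default behaviour: phonemize with config.phoneme_language
seq_default = text_to_seq("Hello there!", config)

# per-call override: phonemize this request with another language code
seq_pt = text_to_seq("Olá, tudo bem?", config, language="pt-br")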
@@ -96,6 +97,7 @@ def run_model_torch( "speaker_ids": speaker_id, "d_vectors": d_vector, "style_mel": style_mel, + "language_ids": language_id, }, ) return outputs @@ -160,19 +162,20 @@ def inv_spectrogram(postnet_output, ap, CONFIG): return wav -def speaker_id_to_torch(speaker_id, cuda=False): - if speaker_id is not None: - speaker_id = np.asarray(speaker_id) - speaker_id = torch.from_numpy(speaker_id) +def id_to_torch(aux_id, cuda=False): + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id) if cuda: - return speaker_id.cuda() - return speaker_id + return aux_id.cuda() + return aux_id def embedding_to_torch(d_vector, cuda=False): if d_vector is not None: d_vector = np.asarray(d_vector) d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) + d_vector = d_vector.squeeze().unsqueeze(0) if cuda: return d_vector.cuda() return d_vector @@ -208,6 +211,8 @@ def synthesis( use_griffin_lim=False, do_trim_silence=False, d_vector=None, + language_id=None, + language_name=None, backend="torch", ): """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to @@ -244,6 +249,12 @@ def synthesis( d_vector (torch.Tensor): d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + language_id (int): + Language ID passed to the language embedding layer in multi-langual model. Defaults to None. + + language_name (str): + Language name corresponding to the language code used by the phonemizer. Defaults to None. + backend (str): tf or torch. Defaults to "torch". """ @@ -258,15 +269,18 @@ def synthesis( if hasattr(model, "make_symbols"): custom_symbols = model.make_symbols(CONFIG) # preprocess the given text - text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols) + text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name) # pass tensors to backend if backend == "torch": if speaker_id is not None: - speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda) + speaker_id = id_to_torch(speaker_id, cuda=use_cuda) if d_vector is not None: d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + if language_id is not None: + language_id = id_to_torch(language_id, cuda=use_cuda) + if not isinstance(style_mel, dict): style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) @@ -278,7 +292,7 @@ def synthesis( text_inputs = tf.expand_dims(text_inputs, 0) # synthesize voice if backend == "torch": - outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector) + outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id) model_outputs = outputs["model_outputs"] model_outputs = model_outputs[0].data.cpu().numpy() alignments = outputs["alignments"] diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index 4b041ed8..f3ffa478 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -135,3 +135,12 @@ def phoneme_cleaners(text): text = remove_aux_symbols(text) text = collapse_whitespace(text) return text + + +def multilingual_cleaners(text): + """Pipeline for multilingual text""" + text = lowercase(text) + text = replace_symbols(text, lang=None) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index e64b95e0..25f93c34 100644 --- a/TTS/utils/audio.py +++ 
b/TTS/utils/audio.py @@ -16,6 +16,60 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method """Some of the audio processing funtions using Torch for faster batch processing. TODO: Merge this with audio.py + + Args: + + n_fft (int): + FFT window size for STFT. + + hop_length (int): + number of frames between STFT columns. + + win_length (int, optional): + STFT window length. + + pad_wav (bool, optional): + If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. + + window (str, optional): + The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" + + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + n_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + use_mel (bool, optional): + If True compute the melspectrograms otherwise. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. + + spec_gain (float, optional): + gain applied when converting amplitude to DB. Defaults to 1.0. + + power (float, optional): + Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. + + use_htk (bool, optional): + Use HTK formula in mel filter instead of Slaney. + + mel_norm (None, 'slaney', or number, optional): + If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). + + If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. + See `librosa.util.normalize` for a full description of supported norm values + (including `+-np.inf`). + + Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". """ def __init__( @@ -32,6 +86,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method use_mel=False, do_amp_to_db=False, spec_gain=1.0, + power=None, + use_htk=False, + mel_norm="slaney", ): super().__init__() self.n_fft = n_fft @@ -45,6 +102,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method self.use_mel = use_mel self.do_amp_to_db = do_amp_to_db self.spec_gain = spec_gain + self.power = power + self.use_htk = use_htk + self.mel_norm = mel_norm self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) self.mel_basis = None if use_mel: @@ -83,6 +143,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method M = o[:, :, :, 0] P = o[:, :, :, 1] S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) + + if self.power is not None: + S = S ** self.power + if self.use_mel: S = torch.matmul(self.mel_basis.to(x), S) if self.do_amp_to_db: @@ -91,7 +155,13 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method def _build_mel_basis(self): mel_basis = librosa.filters.mel( - self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + self.sample_rate, + self.n_fft, + n_mels=self.n_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + htk=self.use_htk, + norm=self.mel_norm, ) self.mel_basis = torch.from_numpy(mel_basis).float() @@ -167,7 +237,7 @@ class AudioProcessor(object): minimum filter frequency for computing melspectrograms. Defaults to None. mel_fmax (int, optional): - maximum filter frequency for computing melspectrograms.. 
Defaults to None. + maximum filter frequency for computing melspectrograms. Defaults to None. spec_gain (int, optional): gain applied when converting amplitude to DB. Defaults to 20. @@ -196,6 +266,12 @@ class AudioProcessor(object): do_amp_to_db_mel (bool, optional): enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + do_rms_norm (bool, optional): + enable/disable RMS volume normalization when loading an audio file. Defaults to False. + + db_level (int, optional): + dB level used for rms normalization. The range is -99 to 0. Defaults to None. + stats_path (str, optional): Path to the computed stats file. Defaults to None. @@ -233,6 +309,8 @@ class AudioProcessor(object): do_sound_norm=False, do_amp_to_db_linear=True, do_amp_to_db_mel=True, + do_rms_norm=False, + db_level=None, stats_path=None, verbose=True, **_, @@ -264,6 +342,8 @@ class AudioProcessor(object): self.do_sound_norm = do_sound_norm self.do_amp_to_db_linear = do_amp_to_db_linear self.do_amp_to_db_mel = do_amp_to_db_mel + self.do_rms_norm = do_rms_norm + self.db_level = db_level self.stats_path = stats_path # setup exp_func for db to amp conversion if log_func == "np.log": @@ -656,21 +736,6 @@ class AudioProcessor(object): frame_period=1000 * self.hop_length / self.sample_rate, ) f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) - # pad = int((self.win_length / self.hop_length) / 2) - # f0 = [0.0] * pad + f0 + [0.0] * pad - # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0) - # f0 = np.array(f0, dtype=np.float32) - - # f01, _, _ = librosa.pyin( - # x, - # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, - # fmax=self.mel_fmax, - # frame_length=self.win_length, - # sr=self.sample_rate, - # fill_na=0.0, - # ) - - # spec = self.melspectrogram(x) return f0 ### Audio Processing ### @@ -713,10 +778,33 @@ class AudioProcessor(object): """ return x / abs(x).max() * 0.95 + @staticmethod + def _rms_norm(wav, db_level=-27): + r = 10 ** (db_level / 20) + a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2)) + return wav * a + + def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: RMS normalized waveform. + """ + if db_level is None: + db_level = self.db_level + assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" + wav = self._rms_norm(x, db_level) + return wav + ### save and load ### def load_wav(self, filename: str, sr: int = None) -> np.ndarray: """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. + Args: filename (str): Path to the wav file. sr (int, optional): Sampling rate for resampling. Defaults to None. @@ -725,8 +813,10 @@ class AudioProcessor(object): np.ndarray: Loaded waveform. """ if self.resample: + # loading with resampling. It is significantly slower. x, sr = librosa.load(filename, sr=self.sample_rate) elif sr is None: + # SF is faster than librosa for loading files x, sr = sf.read(filename) assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) else: @@ -738,6 +828,8 @@ class AudioProcessor(object): print(f" [!] 
File cannot be trimmed for silence - {filename}") if self.do_sound_norm: x = self.sound_norm(x) + if self.do_rms_norm: + x = self.rms_volume_norm(x, self.db_level) return x def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None: diff --git a/TTS/utils/io.py b/TTS/utils/io.py index a93f6118..54818ce9 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -26,7 +26,7 @@ class AttrDict(dict): self.__dict__ = self -def copy_model_files(config: Coqpit, out_path, new_fields): +def copy_model_files(config: Coqpit, out_path, new_fields=None): """Copy config.json and other model files to training folder and add new fields. diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index cfbbdff0..b002da53 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -46,36 +46,66 @@ class ModelManager(object): with open(file_path, "r", encoding="utf-8") as json_file: self.models_dict = json.load(json_file) - def list_langs(self): - print(" Name format: type/language") - for model_type in self.models_dict: - for lang in self.models_dict[model_type]: - print(f" >: {model_type}/{lang} ") + def _list_models(self, model_type, model_count=0): + model_list = [] + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + output_path = os.path.join(self.output_prefix, model_full_name) + if os.path.exists(output_path): + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") + else: + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") + model_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + return model_list - def list_datasets(self): - print(" Name format: type/language/dataset") - for model_type in self.models_dict: - for lang in self.models_dict[model_type]: - for dataset in self.models_dict[model_type][lang]: - print(f" >: {model_type}/{lang}/{dataset}") + def _list_for_model_type(self, model_type): + print(" Name format: language/dataset/model") + models_name_list = [] + model_count = 1 + model_type = "tts_models" + models_name_list.extend(self._list_models(model_type, model_count)) + return [name.replace(model_type + "/", "") for name in models_name_list] def list_models(self): print(" Name format: type/language/dataset/model") models_name_list = [] model_count = 1 + for model_type in self.models_dict: + model_list = self._list_models(model_type, model_count) + models_name_list.extend(model_list) + return models_name_list + + def list_tts_models(self): + """Print all `TTS` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("tts_models") + + def list_vocoder_models(self): + """Print all the `vocoder` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("vocoder_models") + + def list_langs(self): + """Print all the available languages""" + print(" Name format: type/language") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + print(f" >: {model_type}/{lang} ") + + def list_datasets(self): + """Print all the datasets""" + print(" Name format: type/language/dataset") for model_type in self.models_dict: for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: - for model in self.models_dict[model_type][lang][dataset]: - model_full_name = 
f"{model_type}--{lang}--{dataset}--{model}" - output_path = os.path.join(self.output_prefix, model_full_name) - if os.path.exists(output_path): - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") - else: - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") - models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") - model_count += 1 - return models_name_list + print(f" >: {model_type}/{lang}/{dataset}") def download_model(self, model_name): """Download model files given the full model name. @@ -121,6 +151,8 @@ class ModelManager(object): output_stats_path = os.path.join(output_path, "scale_stats.npy") output_d_vector_file_path = os.path.join(output_path, "speakers.json") output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") + speaker_encoder_config_path = os.path.join(output_path, "config_se.json") + speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar") # update the scale_path.npy file path in the model config.json self._update_path("audio.stats_path", output_stats_path, config_path) @@ -133,6 +165,12 @@ class ModelManager(object): self._update_path("speakers_file", output_speaker_ids_file_path, config_path) self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) + # update the speaker_encoder file path in the model config.json to the current path + self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path) + self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) + @staticmethod def _update_path(field_name, new_path, config_path): """Update the path in the model config.json for the current environment after download""" @@ -159,8 +197,12 @@ class ModelManager(object): # download the file r = requests.get(file_url) # extract the file - with zipfile.ZipFile(io.BytesIO(r.content)) as z: - z.extractall(output_folder) + try: + with zipfile.ZipFile(io.BytesIO(r.content)) as z: + z.extractall(output_folder) + except zipfile.BadZipFile: + print(f" > Error: Bad zip file - {file_url}") + raise zipfile.BadZipFile # move the files to the outer path for file_path in z.namelist()[1:]: src_path = os.path.join(output_folder, file_path) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 043c4982..d1d978d8 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,12 +1,13 @@ import time -from typing import List +from typing import List, Union import numpy as np import pysbd import torch -from TTS.config import load_config +from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config from TTS.tts.models import setup_model as setup_tts_model +from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager # pylint: disable=unused-wildcard-import @@ -23,6 +24,7 @@ class Synthesizer(object): tts_checkpoint: str, tts_config_path: str, tts_speakers_file: str = "", + tts_languages_file: str = "", vocoder_checkpoint: str = "", vocoder_config: str = "", encoder_checkpoint: str = "", @@ -52,6 +54,7 @@ class Synthesizer(object): self.tts_checkpoint = tts_checkpoint self.tts_config_path = tts_config_path self.tts_speakers_file = tts_speakers_file + self.tts_languages_file = tts_languages_file 
self.vocoder_checkpoint = vocoder_checkpoint self.vocoder_config = vocoder_config self.encoder_checkpoint = encoder_checkpoint @@ -63,6 +66,9 @@ class Synthesizer(object): self.speaker_manager = None self.num_speakers = 0 self.tts_speakers = {} + self.language_manager = None + self.num_languages = 0 + self.tts_languages = {} self.d_vector_dim = 0 self.seg = self._get_segmenter("en") self.use_cuda = use_cuda @@ -110,29 +116,93 @@ class Synthesizer(object): self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) speaker_manager = self._init_speaker_manager() + language_manager = self._init_language_manager() + self._set_speaker_encoder_paths_from_tts_config() + speaker_manager = self._init_speaker_encoder(speaker_manager) - self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager) + if language_manager is not None: + self.tts_model = setup_tts_model( + config=self.tts_config, + speaker_manager=speaker_manager, + language_manager=language_manager, + ) + else: + self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager) self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) if use_cuda: self.tts_model.cuda() + def _set_speaker_encoder_paths_from_tts_config(self): + """Set the encoder paths from the tts model config for models with speaker encoders.""" + if hasattr(self.tts_config, "model_args") and hasattr( + self.tts_config.model_args, "speaker_encoder_config_path" + ): + self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path + self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path + + def _is_use_speaker_embedding(self): + """Check if the speaker embedding is used in the model""" + # we handle here the case that some models use model_args some don't + use_speaker_embedding = False + if hasattr(self.tts_config, "model_args"): + use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False) + use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False) + return use_speaker_embedding + + def _is_use_d_vector_file(self): + """Check if the d-vector file is used in the model""" + # we handle here the case that some models use model_args some don't + use_d_vector_file = False + if hasattr(self.tts_config, "model_args"): + config = self.tts_config.model_args + use_d_vector_file = config.get("use_d_vector_file", False) + config = self.tts_config + use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False) + return use_d_vector_file + def _init_speaker_manager(self): """Initialize the SpeakerManager""" # setup if multi-speaker settings are in the global model config speaker_manager = None - if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True: + speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None) + if self._is_use_speaker_embedding(): if self.tts_speakers_file: speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file) - if self.tts_config.get("speakers_file", None): - speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file) + if speakers_file: + speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file) - if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True: + if self._is_use_d_vector_file(): + d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, 
"d_vector_file", None) if self.tts_speakers_file: speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file) - if self.tts_config.get("d_vector_file", None): - speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file) + if d_vector_file: + speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file) return speaker_manager + def _init_speaker_encoder(self, speaker_manager): + """Initialize the SpeakerEncoder""" + if self.encoder_checkpoint: + if speaker_manager is None: + speaker_manager = SpeakerManager( + encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config + ) + else: + speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config) + return speaker_manager + + def _init_language_manager(self): + """Initialize the LanguageManager""" + # setup if multi-lingual settings are in the global model config + language_manager = None + if check_config_and_model_args(self.tts_config, "use_language_embedding", True): + if self.tts_languages_file: + language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file) + elif self.tts_config.get("language_ids_file", None): + language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file) + else: + language_manager = LanguageManager(config=self.tts_config) + return language_manager + def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: """Load the vocoder model. @@ -174,13 +244,21 @@ class Synthesizer(object): wav = np.array(wav) self.ap.save_wav(wav, path, self.output_sample_rate) - def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]: + def tts( + self, + text: str, + speaker_name: str = "", + language_name: str = "", + speaker_wav: Union[str, List[str]] = None, + style_wav=None, + ) -> List[int]: """🐸 TTS magic. Run all the models and generate speech. Args: text (str): input text. - speaker_idx (str, optional): spekaer id for multi-speaker models. Defaults to "". - speaker_wav (): + speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "". + language_name (str, optional): language id for multi-language models. Defaults to "". + speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None. style_wav ([type], optional): style waveform for GST. Defaults to None. Returns: @@ -196,29 +274,49 @@ class Synthesizer(object): speaker_embedding = None speaker_id = None if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"): - if speaker_idx and isinstance(speaker_idx, str): + if speaker_name and isinstance(speaker_name, str): if self.tts_config.use_d_vector_file: # get the speaker embedding from the saved d_vectors. - speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0] + speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_name)[0] speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim] else: # get speaker idx from the speaker name - speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx] + speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name] - elif not speaker_idx and not speaker_wav: + elif not speaker_name and not speaker_wav: raise ValueError( " [!] Look like you use a multi-speaker model. " - "You need to define either a `speaker_idx` or a `style_wav` to use a multi-speaker model." 
+ "You need to define either a `speaker_name` or a `style_wav` to use a multi-speaker model." ) else: speaker_embedding = None else: - if speaker_idx: + if speaker_name: raise ValueError( - f" [!] Missing speakers.json file path for selecting speaker {speaker_idx}." + f" [!] Missing speakers.json file path for selecting speaker {speaker_name}." "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. " ) + # handle multi-lingaul + language_id = None + if self.tts_languages_file or ( + hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None + ): + if language_name and isinstance(language_name, str): + language_id = self.tts_model.language_manager.language_id_mapping[language_name] + + elif not language_name: + raise ValueError( + " [!] Look like you use a multi-lingual model. " + "You need to define either a `language_name` or a `style_wav` to use a multi-lingual model." + ) + + else: + raise ValueError( + f" [!] Missing language_ids.json file path for selecting language {language_name}." + "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. " + ) + # compute a new d_vector from the given clip. if speaker_wav is not None: speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav) @@ -234,6 +332,8 @@ class Synthesizer(object): use_cuda=self.use_cuda, ap=self.ap, speaker_id=speaker_id, + language_id=language_id, + language_name=language_name, style_wav=style_wav, enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars, use_griffin_lim=use_gl, diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py new file mode 100644 index 00000000..923544d0 --- /dev/null +++ b/TTS/utils/vad.py @@ -0,0 +1,144 @@ +# This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py +import collections +import contextlib +import wave + +import webrtcvad + + +def read_wave(path): + """Reads a .wav file. + + Takes the path, and returns (PCM audio data, sample rate). + """ + with contextlib.closing(wave.open(path, "rb")) as wf: + num_channels = wf.getnchannels() + assert num_channels == 1 + sample_width = wf.getsampwidth() + assert sample_width == 2 + sample_rate = wf.getframerate() + assert sample_rate in (8000, 16000, 32000, 48000) + pcm_data = wf.readframes(wf.getnframes()) + return pcm_data, sample_rate + + +def write_wave(path, audio, sample_rate): + """Writes a .wav file. + + Takes path, PCM audio data, and sample rate. + """ + with contextlib.closing(wave.open(path, "wb")) as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(audio) + + +class Frame(object): + """Represents a "frame" of audio data.""" + + def __init__(self, _bytes, timestamp, duration): + self.bytes = _bytes + self.timestamp = timestamp + self.duration = duration + + +def frame_generator(frame_duration_ms, audio, sample_rate): + """Generates audio frames from PCM audio data. + + Takes the desired frame duration in milliseconds, the PCM data, and + the sample rate. + + Yields Frames of the requested duration. + """ + n = int(sample_rate * (frame_duration_ms / 1000.0) * 2) + offset = 0 + timestamp = 0.0 + duration = (float(n) / sample_rate) / 2.0 + while offset + n < len(audio): + yield Frame(audio[offset : offset + n], timestamp, duration) + timestamp += duration + offset += n + + +def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames): + """Filters out non-voiced audio frames. 
+ + Given a webrtcvad.Vad and a source of audio frames, yields only + the voiced audio. + + Uses a padded, sliding window algorithm over the audio frames. + When more than 90% of the frames in the window are voiced (as + reported by the VAD), the collector triggers and begins yielding + audio frames. Then the collector waits until 90% of the frames in + the window are unvoiced to detrigger. + + The window is padded at the front and back to provide a small + amount of silence or the beginnings/endings of speech around the + voiced frames. + + Arguments: + + sample_rate - The audio sample rate, in Hz. + frame_duration_ms - The frame duration in milliseconds. + padding_duration_ms - The amount to pad the window, in milliseconds. + vad - An instance of webrtcvad.Vad. + frames - a source of audio frames (sequence or generator). + + Returns: A generator that yields PCM audio data. + """ + num_padding_frames = int(padding_duration_ms / frame_duration_ms) + # We use a deque for our sliding window/ring buffer. + ring_buffer = collections.deque(maxlen=num_padding_frames) + # We have two states: TRIGGERED and NOTTRIGGERED. We start in the + # NOTTRIGGERED state. + triggered = False + + voiced_frames = [] + for frame in frames: + is_speech = vad.is_speech(frame.bytes, sample_rate) + + # sys.stdout.write('1' if is_speech else '0') + if not triggered: + ring_buffer.append((frame, is_speech)) + num_voiced = len([f for f, speech in ring_buffer if speech]) + # If we're NOTTRIGGERED and more than 90% of the frames in + # the ring buffer are voiced frames, then enter the + # TRIGGERED state. + if num_voiced > 0.9 * ring_buffer.maxlen: + triggered = True + # sys.stdout.write('+(%s)' % (ring_buffer[0][0].timestamp,)) + # We want to yield all the audio we see from now until + # we are NOTTRIGGERED, but we have to start with the + # audio that's already in the ring buffer. + for f, _ in ring_buffer: + voiced_frames.append(f) + ring_buffer.clear() + else: + # We're in the TRIGGERED state, so collect the audio data + # and add it to the ring buffer. + voiced_frames.append(frame) + ring_buffer.append((frame, is_speech)) + num_unvoiced = len([f for f, speech in ring_buffer if not speech]) + # If more than 90% of the frames in the ring buffer are + # unvoiced, then enter NOTTRIGGERED and yield whatever + # audio we've collected. + if num_unvoiced > 0.9 * ring_buffer.maxlen: + # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration)) + triggered = False + yield b"".join([f.bytes for f in voiced_frames]) + ring_buffer.clear() + voiced_frames = [] + # If we have any leftover voiced audio when we run out of input, + # yield it. + if voiced_frames: + yield b"".join([f.bytes for f in voiced_frames]) + + +def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300): + + vad = webrtcvad.Vad(int(aggressiveness)) + frames = list(frame_generator(30, audio, sample_rate)) + segments = vad_collector(sample_rate, 30, padding_duration_ms, vad, frames) + + return segments diff --git a/docs/source/models/vits.md b/docs/source/models/vits.md index 5c0e92f6..0c303f7a 100644 --- a/docs/source/models/vits.md +++ b/docs/source/models/vits.md @@ -3,10 +3,15 @@ VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech ) is an End-to-End (encoder -> vocoder together) TTS model that takes advantage of SOTA DL techniques like GANs, VAE, Normalizing Flows. 
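The `TTS/utils/vad.py` module introduced above is self-contained, so a short sketch may help show how its pieces fit together. This is only an assumed usage pattern built from the functions defined in that file; the input/output paths are placeholders and the aggressiveness value is just an example.

```python
# Sketch only: strip silence from a mono 16-bit PCM wav using the helpers
# defined in TTS/utils/vad.py above. "in.wav" / "out.wav" are placeholder paths.
from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave

audio, sample_rate = read_wave("in.wav")  # must be mono, 16-bit, 8/16/32/48 kHz
segments = list(get_vad_speech_segments(audio, sample_rate, aggressiveness=2))

if segments:
    # concatenate the voiced chunks, dropping the silence in between
    write_wave("out.wav", b"".join(segments), sample_rate)
else:
    # no speech detected; keep the original audio
    write_wave("out.wav", audio, sample_rate)
```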
It does not require external alignment annotations and learns the text-to-audio alignment -using MAS as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder. +using MAS, as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder. It is a feed-forward model with x67.12 real-time factor on a GPU. +🐸 YourTTS is a multi-speaker and multi-lingual TTS model that can perform voice conversion and zero-shot speaker adaptation. +It can also learn a new language or voice from roughly one minute of audio, which opens the door to training +TTS models for low-resource languages. 🐸 YourTTS uses VITS as the backbone architecture, coupled with a speaker encoder model. + ## Important resources & papers +- 🐸 YourTTS: https://arxiv.org/abs/2112.02418 - VITS: https://arxiv.org/pdf/2106.06103.pdf - Neural Spline Flows: https://arxiv.org/abs/1906.04032 - Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf diff --git a/notebooks/dataset_analysis/analyze.py b/notebooks/dataset_analysis/analyze.py index 9ba42fb9..4855886e 100644 --- a/notebooks/dataset_analysis/analyze.py +++ b/notebooks/dataset_analysis/analyze.py @@ -180,7 +180,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path): plt.figure() plt.rcParams["figure.figsize"] = (50, 20) - barplot = sns.barplot(x, y) + barplot = sns.barplot(x=x, y=y) if save_path: fig = barplot.get_figure() fig.savefig(os.path.join(save_path, "phoneme_dist")) diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py new file mode 100644 index 00000000..be4747df --- /dev/null +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -0,0 +1,130 @@ +import os +from glob import glob + +from TTS.config.shared_configs import BaseAudioConfig +from TTS.trainer import Trainer, TrainingArgs +from TTS.tts.configs.shared_configs import BaseDatasetConfig +from TTS.tts.configs.vits_config import VitsConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.vits import Vits, VitsArgs +from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.audio import AudioProcessor + +output_path = os.path.dirname(os.path.abspath(__file__)) + +mailabs_path = "/home/julian/workspace/mailabs/**" +dataset_paths = glob(mailabs_path) +dataset_config = [ + BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1]) + for path in dataset_paths +] + +audio_config = BaseAudioConfig( + sample_rate=16000, + win_length=1024, + hop_length=256, + num_mels=80, + preemphasis=0.0, + ref_level_db=20, + log_func="np.log", + do_trim_silence=False, + trim_db=23.0, + mel_fmin=0, + mel_fmax=None, + spec_gain=1.0, + signal_norm=True, + do_amp_to_db_linear=False, + resample=False, +) + +vitsArgs = VitsArgs( + use_language_embedding=True, + embedded_language_dim=4, + use_speaker_embedding=True, + use_sdp=False, +) + +config = VitsConfig( + model_args=vitsArgs, + audio=audio_config, + run_name="vits_vctk", + use_speaker_embedding=True, + batch_size=32, + eval_batch_size=16, + batch_group_size=0, + num_loader_workers=4, + num_eval_loader_workers=4, + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="multilingual_cleaners", + use_phonemes=False, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + compute_input_seq_cache=True, + print_step=25, + use_language_weighted_sampler=True, +
print_eval=False, + mixed_precision=False, + sort_by_audio_len=True, + min_seq_len=32 * 256 * 4, + max_seq_len=160000, + output_path=output_path, + datasets=dataset_config, + characters={ + "pad": "_", + "eos": "&", + "bos": "*", + "characters": "!¡'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„", + "punctuations": "!¡'(),-.:;¿? ", + "phonemes": None, + "unique": True, + }, + test_sentences=[ + [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "mary_ann", + None, + "en_US", + ], + [ + "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.", + "ezwa", + None, + "fr_FR", + ], + ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"], + ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"], + ], +) + +# init audio processor +ap = AudioProcessor(**config.audio.to_dict()) + +# load training samples +train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) + +# init speaker manager for multi-speaker training +# it maps speaker-id to speaker-name in the model and data-loader +speaker_manager = SpeakerManager() +speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) +config.model_args.num_speakers = speaker_manager.num_speakers + +language_manager = LanguageManager(config=config) +config.model_args.num_languages = language_manager.num_languages + +# init model +model = Vits(config, speaker_manager, language_manager) + +# init the trainer and 🚀 +trainer = Trainer( + TrainingArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, +) +trainer.fit() diff --git a/requirements.txt b/requirements.txt index 3ec33ceb..ddb6def9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,5 @@ unidic-lite==1.0.8 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0 fsspec>=2021.04.0 pyworld +webrtcvad +torchaudio diff --git a/tests/__init__.py b/tests/__init__.py index 45aee23a..0a0c3379 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -38,3 +38,14 @@ def run_cli(command): def get_test_data_config(): return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv") + + +def assertHasAttr(test_obj, obj, intendedAttr): + # from https://stackoverflow.com/questions/48078636/pythons-unittest-lacks-an-asserthasattr-method-what-should-i-use-instead + testBool = hasattr(obj, intendedAttr) + test_obj.assertTrue(testBool, msg=f"obj lacking an attribute. obj: {obj}, intendedAttr: {intendedAttr}") + + +def assertHasNotAttr(test_obj, obj, intendedAttr): + testBool = hasattr(obj, intendedAttr) + test_obj.assertFalse(testBool, msg=f"obj should not have an attribute. 
obj: {obj}, intendedAttr: {intendedAttr}") diff --git a/tests/aux_tests/test_find_unique_phonemes.py b/tests/aux_tests/test_find_unique_phonemes.py new file mode 100644 index 00000000..fa0abe4b --- /dev/null +++ b/tests/aux_tests/test_find_unique_phonemes.py @@ -0,0 +1,80 @@ +import os +import unittest + +import torch + +from tests import get_tests_output_path, run_cli +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.vits_config import VitsConfig + +torch.manual_seed(1) + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") + +dataset_config_en = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_val="metadata.csv", + path="tests/data/ljspeech", + language="en", +) + +dataset_config_pt = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_val="metadata.csv", + path="tests/data/ljspeech", + language="pt-br", +) + +# pylint: disable=protected-access +class TestFindUniquePhonemes(unittest.TestCase): + @staticmethod + def test_espeak_phonemes(): + # prepare the config + config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + datasets=[dataset_config_en, dataset_config_pt], + ) + config.save_json(config_path) + + # run test + run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"') + + @staticmethod + def test_no_espeak_phonemes(): + # prepare the config + config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=False, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + datasets=[dataset_config_en, dataset_config_pt], + ) + config.save_json(config_path) + + # run test + run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"') diff --git a/tests/aux_tests/test_remove_silence_vad_script.py b/tests/aux_tests/test_remove_silence_vad_script.py new file mode 100644 index 00000000..c934e065 --- /dev/null +++ b/tests/aux_tests/test_remove_silence_vad_script.py @@ -0,0 +1,29 @@ +import os +import unittest + +import torch + +from tests import get_tests_input_path, get_tests_output_path, run_cli + +torch.manual_seed(1) + +# pylint: disable=protected-access +class TestRemoveSilenceVAD(unittest.TestCase): + @staticmethod + def test(): + # set paths + wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs") + output_path = os.path.join(get_tests_output_path(), "output_wavs_removed_silence/") + output_resample_path = os.path.join(get_tests_output_path(), "output_ljspeech_16khz/") + + # resample audios + run_cli( + f'CUDA_VISIBLE_DEVICES="" python TTS/bin/resample.py --input_dir "{wav_path}" --output_dir "{output_resample_path}" --output_sr 16000' + ) + + # run test + run_cli( + f'CUDA_VISIBLE_DEVICES="" python TTS/bin/remove_silence_using_vad.py --input_dir "{output_resample_path}" --output_dir "{output_path}"' + ) + run_cli(f'rm -rf "{output_resample_path}"') + run_cli(f'rm -rf "{output_path}"') diff --git 
a/tests/aux_tests/test_speaker_encoder.py b/tests/aux_tests/test_speaker_encoder.py index 3c897aa9..97b3b92f 100644 --- a/tests/aux_tests/test_speaker_encoder.py +++ b/tests/aux_tests/test_speaker_encoder.py @@ -13,7 +13,7 @@ file_path = get_tests_input_path() class LSTMSpeakerEncoderTests(unittest.TestCase): # pylint: disable=R0201 def test_in_out(self): - dummy_input = T.rand(4, 20, 80) # B x T x D + dummy_input = T.rand(4, 80, 20) # B x D x T dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)] model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3) # computing d vectors @@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase): assert output.type() == "torch.FloatTensor" assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}" # compute d for a given batch - dummy_input = T.rand(1, 240, 80) # B x T x D + dummy_input = T.rand(1, 80, 240) # B x T x D output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5) assert output.shape[0] == 1 assert output.shape[1] == 256 @@ -44,7 +44,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase): class ResNetSpeakerEncoderTests(unittest.TestCase): # pylint: disable=R0201 def test_in_out(self): - dummy_input = T.rand(4, 20, 80) # B x T x D + dummy_input = T.rand(4, 80, 20) # B x D x T dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)] model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256) # computing d vectors @@ -61,7 +61,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase): assert output.type() == "torch.FloatTensor" assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}" # compute d for a given batch - dummy_input = T.rand(1, 240, 80) # B x T x D + dummy_input = T.rand(1, 80, 240) # B x D x T output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10) assert output.shape[0] == 1 assert output.shape[1] == 256 diff --git a/tests/aux_tests/test_speaker_manager.py b/tests/aux_tests/test_speaker_manager.py index baa50749..fff49b13 100644 --- a/tests/aux_tests/test_speaker_manager.py +++ b/tests/aux_tests/test_speaker_manager.py @@ -6,7 +6,7 @@ import torch from tests import get_tests_input_path from TTS.config import load_config -from TTS.speaker_encoder.utils.generic_utils import setup_model +from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.speaker_encoder.utils.io import save_checkpoint from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor @@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase): config.audio.resample = True # create a dummy speaker encoder - model = setup_model(config) + model = setup_speaker_encoder_model(config) save_checkpoint(model, None, None, get_tests_input_path(), 0) # load audio processor and speaker encoder @@ -38,7 +38,7 @@ class SpeakerManagerTest(unittest.TestCase): # load a sample audio and compute embedding waveform = ap.load_wav(sample_wav_path) mel = ap.melspectrogram(waveform) - d_vector = manager.compute_d_vector(mel.T) + d_vector = manager.compute_d_vector(mel) assert d_vector.shape[1] == 256 # compute d_vector directly from an input file diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 8a20c261..19c2e8f7 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -38,6 +38,11 @@ class TestTTSDataset(unittest.TestCase): def _create_dataloader(self, batch_size, r, bgs): items = ljspeech(c.data_path, "metadata.csv") + + # add a default 
language because now the TTSDataset expect a language + language = "" + items = [[*item, language] for item in items] + dataset = TTSDataset( r, c.text_cleaner, diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py new file mode 100644 index 00000000..3d8d6c75 --- /dev/null +++ b/tests/data_tests/test_samplers.py @@ -0,0 +1,58 @@ +import functools + +import torch + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.languages import get_language_weighted_sampler + +# Fixing random state to avoid random fails +torch.manual_seed(0) + +dataset_config_en = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_val="metadata.csv", + path="tests/data/ljspeech", + language="en", +) + +dataset_config_pt = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_val="metadata.csv", + path="tests/data/ljspeech", + language="pt-br", +) + +# Adding the EN samples twice to create an unbalanced dataset +train_samples, eval_samples = load_tts_samples( + [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True +) + + +def is_balanced(lang_1, lang_2): + return 0.85 < lang_1 / lang_2 < 1.2 + + +random_sampler = torch.utils.data.RandomSampler(train_samples) +ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)]) +en, pt = 0, 0 +for index in ids: + if train_samples[index][3] == "en": + en += 1 + else: + pt += 1 + +assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced" + +weighted_sampler = get_language_weighted_sampler(train_samples) +ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) +en, pt = 0, 0 +for index in ids: + if train_samples[index][3] == "en": + en += 1 + else: + pt += 1 + +assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced" diff --git a/tests/inputs/language_ids.json b/tests/inputs/language_ids.json new file mode 100644 index 00000000..27bb1520 --- /dev/null +++ b/tests/inputs/language_ids.json @@ -0,0 +1,5 @@ +{ + "en": 0, + "fr-fr": 1, + "pt-br": 2 +} \ No newline at end of file diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py new file mode 100644 index 00000000..4274d947 --- /dev/null +++ b/tests/tts_tests/test_vits.py @@ -0,0 +1,240 @@ +import os +import unittest + +import torch + +from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path +from TTS.config import load_config +from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model +from TTS.tts.configs.vits_config import VitsConfig +from TTS.tts.models.vits import Vits, VitsArgs +from TTS.tts.utils.speakers import SpeakerManager + +LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") +SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") + + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +# pylint: disable=no-self-use +class TestVits(unittest.TestCase): + def test_init_multispeaker(self): + num_speakers = 10 + args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True) + model = Vits(args) + assertHasAttr(self, model, "emb_g") + + args = VitsArgs(num_speakers=0, use_speaker_embedding=True) + model = Vits(args) + assertHasNotAttr(self, model, "emb_g") + + args = VitsArgs(num_speakers=10, use_speaker_embedding=False) + model = Vits(args) + 
assertHasNotAttr(self, model, "emb_g") + + args = VitsArgs(d_vector_dim=101, use_d_vector_file=True) + model = Vits(args) + self.assertEqual(model.embedded_speaker_dim, 101) + + def test_init_multilingual(self): + args = VitsArgs(language_ids_file=None, use_language_embedding=False) + model = Vits(args) + self.assertEqual(model.language_manager, None) + self.assertEqual(model.embedded_language_dim, 0) + self.assertEqual(model.emb_l, None) + + args = VitsArgs(language_ids_file=LANG_FILE) + model = Vits(args) + self.assertNotEqual(model.language_manager, None) + self.assertEqual(model.embedded_language_dim, 0) + self.assertEqual(model.emb_l, None) + + args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True) + model = Vits(args) + self.assertNotEqual(model.language_manager, None) + self.assertEqual(model.embedded_language_dim, args.embedded_language_dim) + self.assertNotEqual(model.emb_l, None) + + args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102) + model = Vits(args) + self.assertNotEqual(model.language_manager, None) + self.assertEqual(model.embedded_language_dim, args.embedded_language_dim) + self.assertNotEqual(model.emb_l, None) + + def test_get_aux_input(self): + aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None} + args = VitsArgs() + model = Vits(args) + aux_out = model.get_aux_input(aux_input) + + speaker_id = torch.randint(10, (1,)) + language_id = torch.randint(10, (1,)) + d_vector = torch.rand(1, 128) + aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id} + aux_out = model.get_aux_input(aux_input) + self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape) + self.assertEqual(aux_out["language_ids"].shape, language_id.shape) + self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape) + + def test_voice_conversion(self): + num_speakers = 10 + spec_len = 101 + spec_effective_len = 50 + + args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True) + model = Vits(args) + + ref_inp = torch.randn(1, spec_len, 513) + ref_inp_len = torch.randint(1, spec_effective_len, (1,)) + ref_spk_id = torch.randint(1, num_speakers, (1,)) + tgt_spk_id = torch.randint(1, num_speakers, (1,)) + o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id) + + self.assertEqual(o_hat.shape, (1, 1, spec_len * 256)) + self.assertEqual(y_mask.shape, (1, 1, spec_len)) + self.assertEqual(y_mask.sum(), ref_inp_len[0]) + self.assertEqual(z.shape, (1, args.hidden_channels, spec_len)) + self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len)) + self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len)) + + def _init_inputs(self, config): + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (8,)).long().to(device) + input_lengths[-1] = 128 + spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device) + spec_lengths = torch.randint(20, 30, (8,)).long().to(device) + spec_lengths[-1] = spec.size(2) + waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device) + return input_dummy, input_lengths, spec, spec_lengths, waveform + + def _check_forward_outputs(self, config, output_dict, encoder_config=None): + self.assertEqual( + output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] + ) + 
self.assertEqual(output_dict["alignments"].shape, (8, 128, 30)) + self.assertEqual(output_dict["alignments"].max(), 1) + self.assertEqual(output_dict["alignments"].min(), 0) + self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual( + output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] + ) + if encoder_config: + self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"])) + self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"])) + else: + self.assertEqual(output_dict["gt_spk_emb"], None) + self.assertEqual(output_dict["syn_spk_emb"], None) + + def test_forward(self): + num_speakers = 0 + config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) + config.model_args.spec_segment_size = 10 + input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) + model = Vits(config).to(device) + output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform) + self._check_forward_outputs(config, output_dict) + + def test_multispeaker_forward(self): + num_speakers = 10 + + config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) + config.model_args.spec_segment_size = 10 + + input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) + speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) + + model = Vits(config).to(device) + output_dict = model.forward( + input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids} + ) + self._check_forward_outputs(config, output_dict) + + def test_multilingual_forward(self): + num_speakers = 10 + num_langs = 3 + + args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10) + config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) + + input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) + speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (8,)).long().to(device) + + model = Vits(config).to(device) + output_dict = model.forward( + input_dummy, + input_lengths, + spec, + spec_lengths, + waveform, + aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids}, + ) + self._check_forward_outputs(config, output_dict) + + def test_secl_forward(self): + num_speakers = 10 + num_langs = 3 + + speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG) + speaker_encoder_config.model_params["use_torch_spec"] = True + speaker_encoder = setup_speaker_encoder_model(speaker_encoder_config).to(device) + speaker_manager = SpeakerManager() + speaker_manager.speaker_encoder = speaker_encoder + + args = VitsArgs( + language_ids_file=LANG_FILE, + use_language_embedding=True, + spec_segment_size=10, + use_speaker_encoder_as_loss=True, + ) + config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) + config.audio.sample_rate 
= 16000 + + input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) + speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (8,)).long().to(device) + + model = Vits(config, speaker_manager=speaker_manager).to(device) + output_dict = model.forward( + input_dummy, + input_lengths, + spec, + spec_lengths, + waveform, + aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids}, + ) + self._check_forward_outputs(config, output_dict, speaker_encoder_config) + + def test_inference(self): + num_speakers = 0 + config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) + input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) + model = Vits(config).to(device) + _ = model.inference(input_dummy) + + def test_multispeaker_inference(self): + num_speakers = 10 + config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) + input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) + speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device) + model = Vits(config).to(device) + _ = model.inference(input_dummy, {"speaker_ids": speaker_ids}) + + def test_multilingual_inference(self): + num_speakers = 10 + num_langs = 3 + args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10) + config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) + input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) + speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (1,)).long().to(device) + model = Vits(config).to(device) + _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids}) diff --git a/tests/tts_tests/test_vits_d-vectors_train.py b/tests/tts_tests/test_vits_d-vectors_train.py new file mode 100644 index 00000000..213669f5 --- /dev/null +++ b/tests/tts_tests/test_vits_d-vectors_train.py @@ -0,0 +1,62 @@ +import glob +import os +import shutil + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.vits_config import VitsConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + ["Be a voice, not an echo.", "ljspeech-0"], + ], +) +# set audio config +config.audio.do_trim_silence = True +config.audio.trim_db = 60 + +# active multispeaker d-vec mode +config.model_args.use_d_vector_file = True +config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" +config.model_args.d_vector_dim = 256 + + +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask 
tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py new file mode 100644 index 00000000..1ca57d93 --- /dev/null +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -0,0 +1,91 @@ +import glob +import os +import shutil + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.vits_config import VitsConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +dataset_config_en = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_val="metadata.csv", + path="tests/data/ljspeech", + language="en", +) + +dataset_config_pt = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_val="metadata.csv", + path="tests/data/ljspeech", + language="pt-br", +) + +config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=False, + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + ["Be a voice, not an echo.", "ljspeech-0", None, "en"], + ["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"], + ], + datasets=[dataset_config_en, dataset_config_pt], +) +# set audio config +config.audio.do_trim_silence = True +config.audio.trim_db = 60 + +# active multilingual mode +config.model_args.use_language_embedding = True +config.use_language_embedding = True + +# deactivate multispeaker mode +config.model_args.use_speaker_embedding = False +config.use_speaker_embedding = False + +# active multispeaker d-vec mode +config.model_args.use_d_vector_file = True +config.use_d_vector_file = True +config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" +config.d_vector_file = "tests/data/ljspeech/speakers.json" +config.model_args.d_vector_dim = 256 +config.d_vector_dim = 256 + +# duration predictor +config.model_args.use_sdp = True +config.use_sdp = True + +# deactivate language sampler +config.use_language_weighted_sampler = False + +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_vits_multilingual_train.py b/tests/tts_tests/test_vits_multilingual_train.py new file mode 100644 index 
00000000..50cccca5 --- /dev/null +++ b/tests/tts_tests/test_vits_multilingual_train.py @@ -0,0 +1,88 @@ +import glob +import os +import shutil + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.vits_config import VitsConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +dataset_config_en = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_val="metadata.csv", + path="tests/data/ljspeech", + language="en", +) + +dataset_config_pt = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_val="metadata.csv", + path="tests/data/ljspeech", + language="pt-br", +) + +config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + ["Be a voice, not an echo.", "ljspeech", None, "en"], + ["Be a voice, not an echo.", "ljspeech", None, "pt-br"], + ], + datasets=[dataset_config_en, dataset_config_pt], +) +# set audio config +config.audio.do_trim_silence = True +config.audio.trim_db = 60 + +# active multilingual mode +config.model_args.use_language_embedding = True +config.use_language_embedding = True +# active multispeaker mode +config.model_args.use_speaker_embedding = True +config.use_speaker_embedding = True + +# deactivate multispeaker d-vec mode +config.model_args.use_d_vector_file = False +config.use_d_vector_file = False + +# duration predictor +config.model_args.use_sdp = False +config.use_sdp = False + +# active language sampler +config.use_language_weighted_sampler = True + +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py new file mode 100644 index 00000000..6cc1dabd --- /dev/null +++ b/tests/tts_tests/test_vits_speaker_emb_train.py @@ -0,0 +1,63 @@ +import glob +import os +import shutil + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.vits_config import VitsConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + 
test_sentences=[ + ["Be a voice, not an echo.", "ljspeech"], + ], +) +# set audio config +config.audio.do_trim_silence = True +config.audio.trim_db = 60 + +# active multispeaker d-vec mode +config.model_args.use_speaker_embedding = True +config.model_args.use_d_vector_file = False +config.model_args.d_vector_file = None +config.model_args.d_vector_dim = 256 + + +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py index 6398955e..607f7b29 100644 --- a/tests/tts_tests/test_vits_train.py +++ b/tests/tts_tests/test_vits_train.py @@ -25,7 +25,7 @@ config = VitsConfig( print_step=1, print_eval=True, test_sentences=[ - "Be a voice, not an echo.", + ["Be a voice, not an echo."], ], ) config.audio.do_trim_silence = True diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index 886d1bb6..63d9e7ca 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -4,6 +4,7 @@ import os import shutil from tests import get_tests_output_path, run_cli +from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.manage import ModelManager @@ -17,21 +18,30 @@ def test_run_all_models(): manager = ModelManager(output_prefix=get_tests_output_path()) model_names = manager.list_models() for model_name in model_names: + print(f"\n > Run - {model_name}") model_path, _, _ = manager.download_model(model_name) if "tts_models" in model_name: local_download_dir = os.path.dirname(model_path) # download and run the model speaker_files = glob.glob(local_download_dir + "/speaker*") + language_files = glob.glob(local_download_dir + "/language*") + language_id = "" if len(speaker_files) > 0: # multi-speaker model if "speaker_ids" in speaker_files[0]: speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0]) elif "speakers" in speaker_files[0]: speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0]) + + # multi-lingual model - Assuming multi-lingual models are also multi-speaker + if len(language_files) > 0 and "language_ids" in language_files[0]: + language_manager = LanguageManager(language_ids_file_path=language_files[0]) + language_id = language_manager.language_names[0] + speaker_id = list(speaker_manager.speaker_ids.keys())[0] run_cli( f"tts --model_name {model_name} " - f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"' + f'--text "This is an example." 
--out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" ' ) else: # single-speaker model