From e778bad626d94457833853bbfbe286b5d1a442fb Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Thu, 6 Jan 2022 15:07:27 +0100 Subject: [PATCH 001/214] Add argument to enable dp speaker conditioning --- TTS/tts/models/vits.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index b2e4be9e..cb349ca2 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -171,6 +171,9 @@ class VitsArgs(Coqpit): speaker_encoder_model_path (str): Path to the speaker encoder checkpoint file, to use for SCL. Defaults to "". + condition_dp_on_speaker (bool): + Condition the duration predictor on the speaker embedding. Defaults to True. + freeze_encoder (bool): Freeze the encoder weights during training. Defaults to False. @@ -233,6 +236,7 @@ class VitsArgs(Coqpit): use_speaker_encoder_as_loss: bool = False speaker_encoder_config_path: str = "" speaker_encoder_model_path: str = "" + condition_dp_on_speaker: bool = True freeze_encoder: bool = False freeze_DP: bool = False freeze_PE: bool = False @@ -349,7 +353,7 @@ class Vits(BaseTTS): 3, args.dropout_p_duration_predictor, 4, - cond_channels=self.embedded_speaker_dim, + cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, language_emb_dim=self.embedded_language_dim, ) else: @@ -358,7 +362,7 @@ class Vits(BaseTTS): 256, 3, args.dropout_p_duration_predictor, - cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, language_emb_dim=self.embedded_language_dim, ) @@ -595,12 +599,15 @@ class Vits(BaseTTS): # duration predictor attn_durations = attn.sum(3) + g_dp = None + if self.args.condition_dp_on_speaker: + g_dp = g.detach() if self.args.detach_dp_input and g is not None else g if self.args.use_sdp: loss_duration = self.duration_predictor( x.detach() if self.args.detach_dp_input else x, x_mask, attn_durations, - g=g.detach() if self.args.detach_dp_input and g is not None else g, + g=g_dp, lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, ) loss_duration = loss_duration / torch.sum(x_mask) @@ -609,7 +616,7 @@ class Vits(BaseTTS): log_durations = self.duration_predictor( x.detach() if self.args.detach_dp_input else x, x_mask, - g=g.detach() if self.args.detach_dp_input and g is not None else g, + g=g_dp, lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, ) loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) @@ -685,10 +692,10 @@ class Vits(BaseTTS): if self.args.use_sdp: logw = self.duration_predictor( - x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb ) else: - logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb) + logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb) w = torch.exp(logw) * x_mask * self.length_scale w_ceil = torch.ceil(w) From c7f5e005e17cd80f9aeba5f5b119430dfa193c4f Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Tue, 4 Jan 2022 10:06:57 +0100 Subject: [PATCH 002/214] Compute embedding for new audios only --- TTS/bin/compute_embeddings.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/TTS/bin/compute_embeddings.py
b/TTS/bin/compute_embeddings.py index 83a5aeae..2ac18651 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -29,6 +29,7 @@ parser.add_argument( help="Path to dataset config file.", ) parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.") +parser.add_argument("--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None) parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) parser.add_argument("--eval", type=bool, help="compute eval.", default=True) @@ -40,7 +41,7 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli wav_files = meta_data_train + meta_data_eval speaker_manager = SpeakerManager( - encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda + encoder_model_path=args.model_path, encoder_config_path=args.config_path, d_vectors_file_path=args.old_file, use_cuda=args.use_cuda ) # compute speaker embeddings @@ -52,11 +53,15 @@ for idx, wav_file in enumerate(tqdm(wav_files)): else: speaker_name = None - # extract the embedding - embedd = speaker_manager.compute_d_vector_from_clip(wav_file) + wav_file_name = os.path.basename(wav_file) + if args.old_file is not None and wav_file_name in speaker_manager.clip_ids: + # get the embedding from the old file + embedd = speaker_manager.get_d_vector_by_clip(wav_file_name) + else: + # extract the embedding + embedd = speaker_manager.compute_d_vector_from_clip(wav_file) # create speaker_mapping if target dataset is defined - wav_file_name = os.path.basename(wav_file) speaker_mapping[wav_file_name] = {} speaker_mapping[wav_file_name]["name"] = speaker_name speaker_mapping[wav_file_name]["embedding"] = embedd From 0860d73cf804a99eb89e08133c1a6ee3f1383f4f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 10 Feb 2022 12:14:54 -0300 Subject: [PATCH 003/214] Remove Tensorflow requirement (#1225) * Remove TF modules * Remove TF unit tests * Remove TF vocoder modules * Remove TF convert scripts * Remove TF requirement * Remove the Docs TF instructions * Remove TF inference support --- Makefile | 1 - README.md | 11 +- TTS/bin/convert_melgan_tflite.py | 25 -- TTS/bin/convert_melgan_torch_to_tf.py | 105 ----- TTS/bin/convert_tacotron2_tflite.py | 30 -- TTS/bin/convert_tacotron2_torch_to_tf.py | 187 -------- TTS/tts/layers/tacotron/tacotron2.py | 1 - TTS/tts/tf/README.md | 20 - TTS/tts/tf/__init__.py | 0 TTS/tts/tf/layers/tacotron/__init__.py | 0 TTS/tts/tf/layers/tacotron/common_layers.py | 301 ------------- TTS/tts/tf/layers/tacotron/tacotron2.py | 322 ------------- TTS/tts/tf/models/tacotron2.py | 116 ----- TTS/tts/tf/utils/convert_torch_to_tf_utils.py | 87 ---- TTS/tts/tf/utils/generic_utils.py | 105 ----- TTS/tts/tf/utils/io.py | 45 -- TTS/tts/tf/utils/tf_utils.py | 8 - TTS/tts/tf/utils/tflite.py | 27 -- TTS/tts/utils/synthesis.py | 117 +---- TTS/vocoder/tf/layers/melgan.py | 54 --- TTS/vocoder/tf/layers/pqmf.py | 60 --- TTS/vocoder/tf/models/melgan_generator.py | 133 ------ .../tf/models/multiband_melgan_generator.py | 65 --- TTS/vocoder/tf/utils/__init__.py | 0 .../tf/utils/convert_torch_to_tf_utils.py | 47 -- TTS/vocoder/tf/utils/generic_utils.py | 36 -- TTS/vocoder/tf/utils/io.py | 31 -- TTS/vocoder/tf/utils/tflite.py | 27 -- docs/source/converting_torch_to_tf.md | 21 - docs/source/index.md | 1 - docs/source/installation.md | 6 - ...l_Converting_PyTorch_to_TF_to_TFlite.ipynb | 425 ------------------ requirements.tf.txt | 1
- setup.py | 5 +- tests/tts_tests/test_tacotron2_tf_model.py | 156 ------- .../test_vocoder_tf_melgan_generator.py | 19 - tests/vocoder_tests/test_vocoder_tf_pqmf.py | 31 -- 37 files changed, 19 insertions(+), 2607 deletions(-) delete mode 100644 TTS/bin/convert_melgan_tflite.py delete mode 100644 TTS/bin/convert_melgan_torch_to_tf.py delete mode 100644 TTS/bin/convert_tacotron2_tflite.py delete mode 100644 TTS/bin/convert_tacotron2_torch_to_tf.py delete mode 100644 TTS/tts/tf/README.md delete mode 100644 TTS/tts/tf/__init__.py delete mode 100644 TTS/tts/tf/layers/tacotron/__init__.py delete mode 100644 TTS/tts/tf/layers/tacotron/common_layers.py delete mode 100644 TTS/tts/tf/layers/tacotron/tacotron2.py delete mode 100644 TTS/tts/tf/models/tacotron2.py delete mode 100644 TTS/tts/tf/utils/convert_torch_to_tf_utils.py delete mode 100644 TTS/tts/tf/utils/generic_utils.py delete mode 100644 TTS/tts/tf/utils/io.py delete mode 100644 TTS/tts/tf/utils/tf_utils.py delete mode 100644 TTS/tts/tf/utils/tflite.py delete mode 100644 TTS/vocoder/tf/layers/melgan.py delete mode 100644 TTS/vocoder/tf/layers/pqmf.py delete mode 100644 TTS/vocoder/tf/models/melgan_generator.py delete mode 100644 TTS/vocoder/tf/models/multiband_melgan_generator.py delete mode 100644 TTS/vocoder/tf/utils/__init__.py delete mode 100644 TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py delete mode 100644 TTS/vocoder/tf/utils/generic_utils.py delete mode 100644 TTS/vocoder/tf/utils/io.py delete mode 100644 TTS/vocoder/tf/utils/tflite.py delete mode 100644 docs/source/converting_torch_to_tf.md delete mode 100644 notebooks/Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb delete mode 100644 requirements.tf.txt delete mode 100644 tests/tts_tests/test_tacotron2_tf_model.py delete mode 100644 tests/vocoder_tests/test_vocoder_tf_melgan_generator.py delete mode 100644 tests/vocoder_tests/test_vocoder_tf_pqmf.py diff --git a/Makefile b/Makefile index 32b4638b..2632dbab 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,6 @@ system-deps: ## install linux system deps dev-deps: ## install development deps pip install -r requirements.dev.txt - pip install -r requirements.tf.txt doc-deps: ## install docs dependencies pip install -r docs/requirements.txt diff --git a/README.md b/README.md index 4686ac67..e7774888 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,6 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models - Detailed training logs on the terminal and Tensorboard. - Support for Multi-speaker TTS. - Efficient, flexible, lightweight but feature complete `Trainer API`. -- Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference. - Released and ready-to-use models. - Tools to curate Text2Speech datasets under ```dataset_analysis```. - Utilities to use and test your models. @@ -113,17 +112,11 @@ If you are only interested in [synthesizing speech](https://tts.readthedocs.io/e pip install TTS ``` -By default, this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra. - ```bash -pip install TTS[tf] -``` - If you plan to code or train models, clone 🐸TTS and install it locally. ```bash git clone https://github.com/coqui-ai/TTS -pip install -e .[all,dev,notebooks,tf] # Select the relevant extras +pip install -e .[all,dev,notebooks] # Select the relevant extras ``` If you are on Ubuntu (Debian), you can also run the following commands for installation.
@@ -204,12 +197,10 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht |- train*.py (train your target model.) |- distribute.py (train your TTS model using Multiple GPUs.) |- compute_statistics.py (compute dataset statistics for normalization.) - |- convert*.py (convert target torch model to TF.) |- ... |- tts/ (text to speech models) |- layers/ (model layer definitions) |- models/ (model definitions) - |- tf/ (Tensorflow 2 utilities and model implementations) |- utils/ (model specific utilities.) |- speaker_encoder/ (Speaker Encoder models.) |- (same) diff --git a/TTS/bin/convert_melgan_tflite.py b/TTS/bin/convert_melgan_tflite.py deleted file mode 100644 index a3a3fb66..00000000 --- a/TTS/bin/convert_melgan_tflite.py +++ /dev/null @@ -1,25 +0,0 @@ -# Convert Tensorflow Tacotron2 model to TF-Lite binary - -import argparse - -from TTS.utils.io import load_config -from TTS.vocoder.tf.utils.generic_utils import setup_generator -from TTS.vocoder.tf.utils.io import load_checkpoint -from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite - -parser = argparse.ArgumentParser() -parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.") -parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") -parser.add_argument("--output_path", type=str, help="path to tflite output binary.") -args = parser.parse_args() - -# Set constants -CONFIG = load_config(args.config_path) - -# load the model -model = setup_generator(CONFIG) -model.build_inference() -model = load_checkpoint(model, args.tf_model) - -# create tflite model -tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path) diff --git a/TTS/bin/convert_melgan_torch_to_tf.py b/TTS/bin/convert_melgan_torch_to_tf.py deleted file mode 100644 index c1fb8498..00000000 --- a/TTS/bin/convert_melgan_torch_to_tf.py +++ /dev/null @@ -1,105 +0,0 @@ -import argparse -import os -from difflib import SequenceMatcher - -import numpy as np -import tensorflow as tf -import torch - -from TTS.utils.io import load_config, load_fsspec -from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( - compare_torch_tf, - convert_tf_name, - transfer_weights_torch_to_tf, -) -from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator -from TTS.vocoder.tf.utils.io import save_checkpoint -from TTS.vocoder.utils.generic_utils import setup_generator - -# prevent GPU use -os.environ["CUDA_VISIBLE_DEVICES"] = "" - -# define args -parser = argparse.ArgumentParser() -parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.") -parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") -parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.") -args = parser.parse_args() - -# load model config -config_path = args.config_path -c = load_config(config_path) -num_speakers = 0 - -# init torch model -model = setup_generator(c) -checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu")) -state_dict = checkpoint["model"] -model.load_state_dict(state_dict) -model.remove_weight_norm() -state_dict = model.state_dict() - -# init tf model -model_tf = setup_tf_generator(c) - -common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE" -# get tf_model graph by passing an input -# B x D x T -dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32) -mel_pred = model_tf(dummy_input, training=False) - -# 
get tf variables -tf_vars = model_tf.weights - -# match variable names with fuzzy logic -torch_var_names = list(state_dict.keys()) -tf_var_names = [we.name for we in model_tf.weights] -var_map = [] -for tf_name in tf_var_names: - # skip re-mapped layer names - if tf_name in [name[0] for name in var_map]: - continue - tf_name_edited = convert_tf_name(tf_name) - ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names] - max_idx = np.argmax(ratios) - matching_name = torch_var_names[max_idx] - del torch_var_names[max_idx] - var_map.append((tf_name, matching_name)) - -# pass weights -tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict) - -# Compare TF and TORCH models -# check embedding outputs -model.eval() -dummy_input_torch = torch.ones((1, 80, 10)) -dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy()) -dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1]) -dummy_input_tf = tf.expand_dims(dummy_input_tf, 2) - -out_torch = model.layers[0](dummy_input_torch) -out_tf = model_tf.model_layers[0](dummy_input_tf) -out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :] - -assert compare_torch_tf(out_torch, out_tf_) < 1e-5 - -for i in range(1, len(model.layers)): - print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}") - out_torch = model.layers[i](out_torch) - out_tf = model_tf.model_layers[i](out_tf) - out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :] - diff = compare_torch_tf(out_torch, out_tf_) - assert diff < 1e-5, diff - -torch.manual_seed(0) -dummy_input_torch = torch.rand((1, 80, 100)) -dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy()) -model.inference_padding = 0 -model_tf.inference_padding = 0 -output_torch = model.inference(dummy_input_torch) -output_tf = model_tf(dummy_input_tf, training=False) -assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf) - -# save tf model -save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path) -print(" > Model conversion is successfully completed :).") diff --git a/TTS/bin/convert_tacotron2_tflite.py b/TTS/bin/convert_tacotron2_tflite.py deleted file mode 100644 index 327d0ae8..00000000 --- a/TTS/bin/convert_tacotron2_tflite.py +++ /dev/null @@ -1,30 +0,0 @@ -# Convert Tensorflow Tacotron2 model to TF-Lite binary - -import argparse - -from TTS.tts.tf.utils.generic_utils import setup_model -from TTS.tts.tf.utils.io import load_checkpoint -from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite -from TTS.tts.utils.text.symbols import phonemes, symbols -from TTS.utils.io import load_config - -parser = argparse.ArgumentParser() -parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.") -parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") -parser.add_argument("--output_path", type=str, help="path to tflite output binary.") -args = parser.parse_args() - -# Set constants -CONFIG = load_config(args.config_path) - -# load the model -c = CONFIG -num_speakers = 0 -num_chars = len(phonemes) if c.use_phonemes else len(symbols) -model = setup_model(num_chars, num_speakers, c, enable_tflite=True) -model.build_inference() -model = load_checkpoint(model, args.tf_model) -model.decoder.set_max_decoder_steps(1000) - -# create tflite model -tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path) diff --git a/TTS/bin/convert_tacotron2_torch_to_tf.py 
b/TTS/bin/convert_tacotron2_torch_to_tf.py deleted file mode 100644 index 78c6b362..00000000 --- a/TTS/bin/convert_tacotron2_torch_to_tf.py +++ /dev/null @@ -1,187 +0,0 @@ -import argparse -import os -import sys -from difflib import SequenceMatcher -from pprint import pprint - -import numpy as np -import tensorflow as tf -import torch - -from TTS.tts.models import setup_model -from TTS.tts.tf.models.tacotron2 import Tacotron2 -from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf -from TTS.tts.tf.utils.generic_utils import save_checkpoint -from TTS.tts.utils.text.symbols import phonemes, symbols -from TTS.utils.io import load_config, load_fsspec - -sys.path.append("/home/erogol/Projects") -os.environ["CUDA_VISIBLE_DEVICES"] = "" - - -parser = argparse.ArgumentParser() -parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.") -parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") -parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.") -args = parser.parse_args() - -# load model config -config_path = args.config_path -c = load_config(config_path) -num_speakers = 0 - -# init torch model -model = setup_model(c) -checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu")) -state_dict = checkpoint["model"] -model.load_state_dict(state_dict) - -# init tf model -num_chars = len(phonemes) if c.use_phonemes else len(symbols) -model_tf = Tacotron2( - num_chars=num_chars, - num_speakers=num_speakers, - r=model.decoder.r, - out_channels=c.audio["num_mels"], - decoder_output_dim=c.audio["num_mels"], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, -) - -# set initial layer mapping - these are not captured by the below heuristic approach -# TODO: set layer names so that we can remove these manual matching -common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE" -var_map = [ - ("embedding/embeddings:0", "embedding.weight"), - ("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"), - ("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"), - ("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"), - ("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"), - ("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")), - ( - "encoder/lstm/backward_lstm/lstm_cell_2/bias:0", - ("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"), - ), - ("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"), - ("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"), - ("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"), -] - -# %% -# get tf_model graph -model_tf.build_inference() - -# get tf variables -tf_vars = model_tf.weights - -# match variable names with fuzzy logic -torch_var_names = list(state_dict.keys()) -tf_var_names = [we.name for we in model_tf.weights] -for tf_name in tf_var_names: - # 
skip re-mapped layer names - if tf_name in [name[0] for name in var_map]: - continue - tf_name_edited = convert_tf_name(tf_name) - ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names] - max_idx = np.argmax(ratios) - matching_name = torch_var_names[max_idx] - del torch_var_names[max_idx] - var_map.append((tf_name, matching_name)) - -pprint(var_map) -pprint(torch_var_names) - -# pass weights -tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict) - -# Compare TF and TORCH models -# %% -# check embedding outputs -model.eval() -input_ids = torch.randint(0, 24, (1, 128)).long() - -o_t = model.embedding(input_ids) -o_tf = model_tf.embedding(input_ids.detach().numpy()) -assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum() - -# compare encoder outputs -oo_en = model.encoder.inference(o_t.transpose(1, 2)) -ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False) -assert compare_torch_tf(oo_en, ooo_en) < 1e-5 - -# pylint: disable=redefined-builtin -# compare decoder.attention_rnn -inp = torch.rand([1, 768]) -inp_tf = inp.numpy() -model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access -output, cell_state = model.decoder.attention_rnn(inp) -states = model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False) -assert compare_torch_tf(output, output_tf).mean() < 1e-5 - -query = output -inputs = torch.rand([1, 128, 512]) -query_tf = query.detach().numpy() -inputs_tf = inputs.numpy() - -# compare decoder.attention -model.decoder.attention.init_states(inputs) -processes_inputs = model.decoder.attention.preprocess_inputs(inputs) -loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs) -context = model.decoder.attention(query, inputs, processes_inputs, None) - -attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1] -model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf)) -loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states) -context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False) - -assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5 -assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5 -assert compare_torch_tf(context, context_tf) < 1e-5 - -# compare decoder.decoder_rnn -input = torch.rand([1, 1536]) -input_tf = input.numpy() -model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access -output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell]) -states = model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False) -assert abs(input - input_tf).mean() < 1e-5 -assert compare_torch_tf(output, output_tf).mean() < 1e-5 - -# compare decoder.linear_projection -input = torch.rand([1, 1536]) -input_tf = input.numpy() -output = model.decoder.linear_projection(input) -output_tf = model_tf.decoder.linear_projection(input_tf, training=False) -assert compare_torch_tf(output, output_tf) < 1e-5 - -# compare decoder outputs -model.decoder.max_decoder_steps = 100 -model_tf.decoder.set_max_decoder_steps(100) -output, align, stop = model.decoder.inference(oo_en) -states = 
model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False) -assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4 - -# compare the whole model output -outputs_torch = model.inference(input_ids) -outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy())) -print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean()) -assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5 -assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4 - -# %% -# save tf model -save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path) -print(" > Model conversion is successfully completed :).") diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py index 9c33623e..c79b7099 100644 --- a/TTS/tts/layers/tacotron/tacotron2.py +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -6,7 +6,6 @@ from .attentions import init_attn from .common_layers import Linear, Prenet -# NOTE: linter has a problem with the current TF release # pylint: disable=no-value-for-parameter # pylint: disable=unexpected-keyword-arg class ConvBNBlock(nn.Module): diff --git a/TTS/tts/tf/README.md b/TTS/tts/tf/README.md deleted file mode 100644 index 0f9d58e9..00000000 --- a/TTS/tts/tf/README.md +++ /dev/null @@ -1,20 +0,0 @@ -## Utilities to Convert Models to Tensorflow2 -Here there are experimental utilities to convert trained Torch models to Tensorflow (2.2>=). - -Converting Torch models to TF enables all the TF toolkit to be used for better deployment and device specific optimizations. - -Note that we do not plan to share training scripts for Tensorflow in near future. But any contribution in that direction would be more than welcome. - -To see how you can use TF model at inference, check the notebook. - -This is an experimental release. If you encounter an error, please put an issue or in the best send a PR but you are mostly on your own. - - -### Converting a Model -- Run ```convert_tacotron2_torch_to_tf.py --torch_model_path /path/to/torch/model.pth.tar --config_path /path/to/model/config.json --output_path /path/to/output/tf/model``` with the right arguments. - -### Known issues ans limitations -- We use a custom model load/save mechanism which enables us to store model related information with models weights. (Similar to Torch). However, it is prone to random errors. -- Current TF model implementation is slightly slower than Torch model. Hopefully, it'll get better with improving TF support for eager mode and ```tf.function```. -- TF implementation of Tacotron2 only supports regular Tacotron2 as in the paper. -- You can only convert models trained after TF model implementation since model layers has been updated in Torch model. 
diff --git a/TTS/tts/tf/__init__.py b/TTS/tts/tf/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/TTS/tts/tf/layers/tacotron/__init__.py b/TTS/tts/tf/layers/tacotron/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/TTS/tts/tf/layers/tacotron/common_layers.py b/TTS/tts/tf/layers/tacotron/common_layers.py deleted file mode 100644 index a6b87981..00000000 --- a/TTS/tts/tf/layers/tacotron/common_layers.py +++ /dev/null @@ -1,301 +0,0 @@ -import tensorflow as tf -from tensorflow import keras -from tensorflow.python.ops import math_ops - -# from tensorflow_addons.seq2seq import BahdanauAttention - -# NOTE: linter has a problem with the current TF release -# pylint: disable=no-value-for-parameter -# pylint: disable=unexpected-keyword-arg - - -class Linear(keras.layers.Layer): - def __init__(self, units, use_bias, **kwargs): - super().__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer") - self.activation = keras.layers.ReLU() - - def call(self, x): - """ - shapes: - x: B x T x C - """ - return self.activation(self.linear_layer(x)) - - -class LinearBN(keras.layers.Layer): - def __init__(self, units, use_bias, **kwargs): - super().__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer") - self.batch_normalization = keras.layers.BatchNormalization( - axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization" - ) - self.activation = keras.layers.ReLU() - - def call(self, x, training=None): - """ - shapes: - x: B x T x C - """ - out = self.linear_layer(x) - out = self.batch_normalization(out, training=training) - return self.activation(out) - - -class Prenet(keras.layers.Layer): - def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs): - super().__init__(**kwargs) - self.prenet_type = prenet_type - self.prenet_dropout = prenet_dropout - self.linear_layers = [] - if prenet_type == "bn": - self.linear_layers += [ - LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units) - ] - elif prenet_type == "original": - self.linear_layers += [ - Linear(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units) - ] - else: - raise RuntimeError(" [!] 
Unknown prenet type.") - if prenet_dropout: - self.dropout = keras.layers.Dropout(rate=0.5) - - def call(self, x, training=None): - """ - shapes: - x: B x T x C - """ - for linear in self.linear_layers: - if self.prenet_dropout: - x = self.dropout(linear(x), training=training) - else: - x = linear(x) - return x - - -def _sigmoid_norm(score): - attn_weights = tf.nn.sigmoid(score) - attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True) - return attn_weights - - -class Attention(keras.layers.Layer): - """TODO: implement forward_attention - TODO: location sensitive attention - TODO: implement attention windowing""" - - def __init__( - self, - attn_dim, - use_loc_attn, - loc_attn_n_filters, - loc_attn_kernel_size, - use_windowing, - norm, - use_forward_attn, - use_trans_agent, - use_forward_attn_mask, - **kwargs, - ): - super().__init__(**kwargs) - self.use_loc_attn = use_loc_attn - self.loc_attn_n_filters = loc_attn_n_filters - self.loc_attn_kernel_size = loc_attn_kernel_size - self.use_windowing = use_windowing - self.norm = norm - self.use_forward_attn = use_forward_attn - self.use_trans_agent = use_trans_agent - self.use_forward_attn_mask = use_forward_attn_mask - self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer") - self.inputs_layer = tf.keras.layers.Dense( - attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer" - ) - self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer") - if use_loc_attn: - self.location_conv1d = keras.layers.Conv1D( - filters=loc_attn_n_filters, - kernel_size=loc_attn_kernel_size, - padding="same", - use_bias=False, - name="location_layer/location_conv1d", - ) - self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense") - if norm == "softmax": - self.norm_func = tf.nn.softmax - elif norm == "sigmoid": - self.norm_func = _sigmoid_norm - else: - raise ValueError("Unknown value for attention norm type") - - def init_states(self, batch_size, value_length): - states = [] - if self.use_loc_attn: - attention_cum = tf.zeros([batch_size, value_length]) - attention_old = tf.zeros([batch_size, value_length]) - states = [attention_cum, attention_old] - if self.use_forward_attn: - alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1) - states.append(alpha) - return tuple(states) - - def process_values(self, values): - """cache values for decoder iterations""" - # pylint: disable=attribute-defined-outside-init - self.processed_values = self.inputs_layer(values) - self.values = values - - def get_loc_attn(self, query, states): - """compute location attention, query layer and - unnorm. 
attention weights""" - attention_cum, attention_old = states[:2] - attn_cat = tf.stack([attention_old, attention_cum], axis=2) - - processed_query = self.query_layer(tf.expand_dims(query, 1)) - processed_attn = self.location_dense(self.location_conv1d(attn_cat)) - score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn)) - score = tf.squeeze(score, axis=2) - return score, processed_query - - def get_attn(self, query): - """compute query layer and unnormalized attention weights""" - processed_query = self.query_layer(tf.expand_dims(query, 1)) - score = self.v(tf.nn.tanh(self.processed_values + processed_query)) - score = tf.squeeze(score, axis=2) - return score, processed_query - - def apply_score_masking(self, score, mask): # pylint: disable=no-self-use - """ignore sequence paddings""" - padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) - # Bias so padding positions do not contribute to attention distribution. - score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32) - return score - - def apply_forward_attention(self, alignment, alpha): # pylint: disable=no-self-use - # forward attention - fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0) - # compute transition potentials - new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment - # renormalize attention weights - new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True) - return new_alpha - - def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None): - states = [] - if self.use_loc_attn: - states = [old_states[0] + scores_norm, attn_weights] - if self.use_forward_attn: - states.append(new_alpha) - return tuple(states) - - def call(self, query, states): - """ - shapes: - query: B x D - """ - if self.use_loc_attn: - score, _ = self.get_loc_attn(query, states) - else: - score, _ = self.get_attn(query) - - # TODO: masking - # if mask is not None: - # self.apply_score_masking(score, mask) - # attn_weights shape == (batch_size, max_length, 1) - - # normalize attention scores - scores_norm = self.norm_func(score) - attn_weights = scores_norm - - # apply forward attention - new_alpha = None - if self.use_forward_attn: - new_alpha = self.apply_forward_attention(attn_weights, states[-1]) - attn_weights = new_alpha - - # update states tuple - # states = (cum_attn_weights, attn_weights, new_alpha) - states = self.update_states(states, scores_norm, attn_weights, new_alpha) - - # context_vector shape after sum == (batch_size, hidden_size) - context_vector = tf.matmul( - tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False - ) - context_vector = tf.squeeze(context_vector, axis=1) - return context_vector, attn_weights, states - - -# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b): -# dtype = processed_query.dtype -# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1] -# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2]) - - -# class LocationSensitiveAttention(BahdanauAttention): -# def __init__(self, -# units, -# memory=None, -# memory_sequence_length=None, -# normalize=False, -# probability_fn="softmax", -# kernel_initializer="glorot_uniform", -# dtype=None, -# name="LocationSensitiveAttention", -# location_attention_filters=32, -# location_attention_kernel_size=31): - -# super( self).__init__(units=units, -# memory=memory, -# memory_sequence_length=memory_sequence_length, -# 
normalize=normalize, -# probability_fn='softmax', ## parent module default -# kernel_initializer=kernel_initializer, -# dtype=dtype, -# name=name) -# if probability_fn == 'sigmoid': -# self.probability_fn = lambda score, _: self._sigmoid_normalization(score) -# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False) -# self.location_dense = keras.layers.Dense(units, use_bias=False) -# # self.v = keras.layers.Dense(1, use_bias=True) - -# def _location_sensitive_score(self, processed_query, keys, processed_loc): -# processed_query = tf.expand_dims(processed_query, 1) -# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2]) - -# def _location_sensitive(self, alignment_cum, alignment_old): -# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2) -# return self.location_dense(self.location_conv(alignment_cat)) - -# def _sigmoid_normalization(self, score): -# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True) - -# # def _apply_masking(self, score, mask): -# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) -# # # Bias so padding positions do not contribute to attention distribution. -# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) -# # return score - -# def _calculate_attention(self, query, state): -# alignment_cum, alignment_old = state[:2] -# processed_query = self.query_layer( -# query) if self.query_layer else query -# processed_loc = self._location_sensitive(alignment_cum, alignment_old) -# score = self._location_sensitive_score( -# processed_query, -# self.keys, -# processed_loc) -# alignment = self.probability_fn(score, state) -# alignment_cum = alignment_cum + alignment -# state[0] = alignment_cum -# state[1] = alignment -# return alignment, state - -# def compute_context(self, alignments): -# expanded_alignments = tf.expand_dims(alignments, 1) -# context = tf.matmul(expanded_alignments, self.values) -# context = tf.squeeze(context, [1]) -# return context - -# # def call(self, query, state): -# # alignment, next_state = self._calculate_attention(query, state) -# # return alignment, next_state diff --git a/TTS/tts/tf/layers/tacotron/tacotron2.py b/TTS/tts/tf/layers/tacotron/tacotron2.py deleted file mode 100644 index 1fe679d2..00000000 --- a/TTS/tts/tf/layers/tacotron/tacotron2.py +++ /dev/null @@ -1,322 +0,0 @@ -import tensorflow as tf -from tensorflow import keras - -from TTS.tts.tf.layers.tacotron.common_layers import Attention, Prenet -from TTS.tts.tf.utils.tf_utils import shape_list - - -# NOTE: linter has a problem with the current TF release -# pylint: disable=no-value-for-parameter -# pylint: disable=unexpected-keyword-arg -class ConvBNBlock(keras.layers.Layer): - def __init__(self, filters, kernel_size, activation, **kwargs): - super().__init__(**kwargs) - self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d") - self.batch_normalization = keras.layers.BatchNormalization( - axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization" - ) - self.dropout = keras.layers.Dropout(rate=0.5, name="dropout") - self.activation = keras.layers.Activation(activation, name="activation") - - def call(self, x, training=None): - o = self.convolution1d(x) - o = self.batch_normalization(o, training=training) - o = self.activation(o) - o = self.dropout(o, training=training) - return o - - -class Postnet(keras.layers.Layer): - def 
__init__(self, output_filters, num_convs, **kwargs): - super().__init__(**kwargs) - self.convolutions = [] - self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0")) - for idx in range(1, num_convs - 1): - self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}")) - self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}")) - - def call(self, x, training=None): - o = x - for layer in self.convolutions: - o = layer(o, training=training) - return o - - -class Encoder(keras.layers.Layer): - def __init__(self, output_input_dim, **kwargs): - super().__init__(**kwargs) - self.convolutions = [] - for idx in range(3): - self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}")) - self.lstm = keras.layers.Bidirectional( - keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm" - ) - - def call(self, x, training=None): - o = x - for layer in self.convolutions: - o = layer(o, training=training) - o = self.lstm(o) - return o - - -class Decoder(keras.layers.Layer): - # pylint: disable=unused-argument - def __init__( - self, - frame_dim, - r, - attn_type, - use_attn_win, - attn_norm, - prenet_type, - prenet_dropout, - use_forward_attn, - use_trans_agent, - use_forward_attn_mask, - use_location_attn, - attn_K, - separate_stopnet, - speaker_emb_dim, - enable_tflite, - **kwargs, - ): - super().__init__(**kwargs) - self.frame_dim = frame_dim - self.r_init = tf.constant(r, dtype=tf.int32) - self.r = tf.constant(r, dtype=tf.int32) - self.output_dim = r * self.frame_dim - self.separate_stopnet = separate_stopnet - self.enable_tflite = enable_tflite - - # layer constants - self.max_decoder_steps = tf.constant(1000, dtype=tf.int32) - self.stop_thresh = tf.constant(0.5, dtype=tf.float32) - - # model dimensions - self.query_dim = 1024 - self.decoder_rnn_dim = 1024 - self.prenet_dim = 256 - self.attn_dim = 128 - self.p_attention_dropout = 0.1 - self.p_decoder_dropout = 0.1 - - self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet") - self.attention_rnn = keras.layers.LSTMCell( - self.query_dim, - use_bias=True, - name="attention_rnn", - ) - self.attention_rnn_dropout = keras.layers.Dropout(0.5) - - # TODO: implement other attn options - self.attention = Attention( - attn_dim=self.attn_dim, - use_loc_attn=True, - loc_attn_n_filters=32, - loc_attn_kernel_size=31, - use_windowing=False, - norm=attn_norm, - use_forward_attn=use_forward_attn, - use_trans_agent=use_trans_agent, - use_forward_attn_mask=use_forward_attn_mask, - name="attention", - ) - self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn") - self.decoder_rnn_dropout = keras.layers.Dropout(0.5) - self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer") - self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer") - - def set_max_decoder_steps(self, new_max_steps): - self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32) - - def set_r(self, new_r): - self.r = tf.constant(new_r, dtype=tf.int32) - self.output_dim = self.frame_dim * new_r - - def build_decoder_initial_states(self, batch_size, memory_dim, memory_length): - zero_frame = tf.zeros([batch_size, self.frame_dim]) - zero_context = tf.zeros([batch_size, memory_dim]) - attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) - decoder_rnn_state = 
self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) - attention_states = self.attention.init_states(batch_size, memory_length) - return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states - - def step(self, prenet_next, states, memory_seq_length=None, training=None): - _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states - attention_rnn_input = tf.concat([prenet_next, context_next], -1) - attention_rnn_output, attention_rnn_state = self.attention_rnn( - attention_rnn_input, attention_rnn_state, training=training - ) - attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training) - context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training) - decoder_rnn_input = tf.concat([attention_rnn_output, context], -1) - decoder_rnn_output, decoder_rnn_state = self.decoder_rnn( - decoder_rnn_input, decoder_rnn_state, training=training - ) - decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training) - linear_projection_input = tf.concat([decoder_rnn_output, context], -1) - output_frame = self.linear_projection(linear_projection_input, training=training) - stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1) - stopnet_output = self.stopnet(stopnet_input, training=training) - output_frame = output_frame[:, : self.r * self.frame_dim] - states = ( - output_frame[:, self.frame_dim * (self.r - 1) :], - context, - attention_rnn_state, - decoder_rnn_state, - attention_states, - ) - return output_frame, stopnet_output, states, attention - - def decode(self, memory, states, frames, memory_seq_length=None): - B, _, _ = shape_list(memory) - num_iter = shape_list(frames)[1] // self.r - # init states - frame_zero = tf.expand_dims(states[0], 1) - frames = tf.concat([frame_zero, frames], axis=1) - outputs = tf.TensorArray(dtype=tf.float32, size=num_iter) - attentions = tf.TensorArray(dtype=tf.float32, size=num_iter) - stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter) - # pre-computes - self.attention.process_values(memory) - prenet_output = self.prenet(frames, training=True) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions): - prenet_next = prenet_output[:, step] - output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length) - outputs = outputs.write(step, output) - attentions = attentions.write(step, attention) - stop_tokens = stop_tokens.write(step, stop_token) - return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions - - _, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop( - lambda *arg: True, - _body, - loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=num_iter, - ) - - outputs = outputs.stack() - attentions = attentions.stack() - stop_tokens = stop_tokens.stack() - outputs = tf.transpose(outputs, [1, 0, 2]) - attentions = tf.transpose(attentions, [1, 0, 2]) - stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) - stop_tokens = tf.squeeze(stop_tokens, axis=2) - outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - def decode_inference(self, memory, states): - B, _, _ = shape_list(memory) - # init states - outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - 
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - - # pre-computes - self.attention.process_values(memory) - - # iter vars - stop_flag = tf.constant(False, dtype=tf.bool) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag): - frame_next = states[0] - prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, attention = self.step(prenet_next, states, None, training=False) - stop_token = tf.math.sigmoid(stop_token) - outputs = outputs.write(step, output) - attentions = attentions.write(step, attention) - stop_tokens = stop_tokens.write(step, stop_token) - stop_flag = tf.greater(stop_token, self.stop_thresh) - stop_flag = tf.reduce_all(stop_flag) - return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag - - cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - _, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop( - cond, - _body, - loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps, - ) - - outputs = outputs.stack() - attentions = attentions.stack() - stop_tokens = stop_tokens.stack() - - outputs = tf.transpose(outputs, [1, 0, 2]) - attentions = tf.transpose(attentions, [1, 0, 2]) - stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) - stop_tokens = tf.squeeze(stop_tokens, axis=2) - outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - def decode_inference_tflite(self, memory, states): - """Inference with TF-Lite compatibility. 
It assumes - batch_size is 1""" - # init states - # dynamic_shape is not supported in TFLite - outputs = tf.TensorArray( - dtype=tf.float32, - size=self.max_decoder_steps, - element_shape=tf.TensorShape([self.output_dim]), - clear_after_read=False, - dynamic_size=False, - ) - # stop_flags = tf.TensorArray(dtype=tf.bool, - # size=self.max_decoder_steps, - # element_shape=tf.TensorShape( - # []), - # clear_after_read=False, - # dynamic_size=False) - attentions = () - stop_tokens = () - - # pre-computes - self.attention.process_values(memory) - - # iter vars - stop_flag = tf.constant(False, dtype=tf.bool) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, states, outputs, stop_flag): - frame_next = states[0] - prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, _ = self.step(prenet_next, states, None, training=False) - stop_token = tf.math.sigmoid(stop_token) - stop_flag = tf.greater(stop_token, self.stop_thresh) - stop_flag = tf.reduce_all(stop_flag) - # stop_flags = stop_flags.write(step, tf.logical_not(stop_flag)) - - outputs = outputs.write(step, tf.reshape(output, [-1])) - return step + 1, memory, states, outputs, stop_flag - - cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - step_count, memory, states, outputs, stop_flag = tf.while_loop( - cond, - _body, - loop_vars=(step_count, memory, states, outputs, stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps, - ) - - outputs = outputs.stack() - outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter - outputs = tf.expand_dims(outputs, axis=[0]) - outputs = tf.transpose(outputs, [1, 0, 2]) - outputs = tf.reshape(outputs, [1, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - def call(self, memory, states, frames=None, memory_seq_length=None, training=False): - if training: - return self.decode(memory, states, frames, memory_seq_length) - if self.enable_tflite: - return self.decode_inference_tflite(memory, states) - return self.decode_inference(memory, states) diff --git a/TTS/tts/tf/models/tacotron2.py b/TTS/tts/tf/models/tacotron2.py deleted file mode 100644 index 7a1d695d..00000000 --- a/TTS/tts/tf/models/tacotron2.py +++ /dev/null @@ -1,116 +0,0 @@ -import tensorflow as tf -from tensorflow import keras - -from TTS.tts.tf.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet -from TTS.tts.tf.utils.tf_utils import shape_list - - -# pylint: disable=too-many-ancestors, abstract-method -class Tacotron2(keras.models.Model): - def __init__( - self, - num_chars, - num_speakers, - r, - out_channels=80, - decoder_output_dim=80, - attn_type="original", - attn_win=False, - attn_norm="softmax", - attn_K=4, - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - separate_stopnet=True, - bidirectional_decoder=False, - enable_tflite=False, - ): - super().__init__() - self.r = r - self.decoder_output_dim = decoder_output_dim - self.out_channels = out_channels - self.bidirectional_decoder = bidirectional_decoder - self.num_speakers = num_speakers - self.speaker_embed_dim = 256 - self.enable_tflite = enable_tflite - - self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding") - self.encoder = Encoder(512, name="encoder") - # TODO: most of the decoder args have no use at the momment - self.decoder = Decoder( - decoder_output_dim, - r, - 
attn_type=attn_type, - use_attn_win=attn_win, - attn_norm=attn_norm, - prenet_type=prenet_type, - prenet_dropout=prenet_dropout, - use_forward_attn=forward_attn, - use_trans_agent=trans_agent, - use_forward_attn_mask=forward_attn_mask, - use_location_attn=location_attn, - attn_K=attn_K, - separate_stopnet=separate_stopnet, - speaker_emb_dim=self.speaker_embed_dim, - name="decoder", - enable_tflite=enable_tflite, - ) - self.postnet = Postnet(out_channels, 5, name="postnet") - - @tf.function(experimental_relax_shapes=True) - def call(self, characters, text_lengths=None, frames=None, training=None): - if training: - return self.training(characters, text_lengths, frames) - if not training: - return self.inference(characters) - raise RuntimeError(" [!] Set model training mode True or False") - - def training(self, characters, text_lengths, frames): - B, T = shape_list(characters) - embedding_vectors = self.embedding(characters, training=True) - encoder_output = self.encoder(embedding_vectors, training=True) - decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) - decoder_frames, stop_tokens, attentions = self.decoder( - encoder_output, decoder_states, frames, text_lengths, training=True - ) - postnet_frames = self.postnet(decoder_frames, training=True) - output_frames = decoder_frames + postnet_frames - return decoder_frames, output_frames, attentions, stop_tokens - - def inference(self, characters): - B, T = shape_list(characters) - embedding_vectors = self.embedding(characters, training=False) - encoder_output = self.encoder(embedding_vectors, training=False) - decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) - decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False) - postnet_frames = self.postnet(decoder_frames, training=False) - output_frames = decoder_frames + postnet_frames - print(output_frames.shape) - return decoder_frames, output_frames, attentions, stop_tokens - - @tf.function( - experimental_relax_shapes=True, - input_signature=[ - tf.TensorSpec([1, None], dtype=tf.int32), - ], - ) - def inference_tflite(self, characters): - B, T = shape_list(characters) - embedding_vectors = self.embedding(characters, training=False) - encoder_output = self.encoder(embedding_vectors, training=False) - decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) - decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False) - postnet_frames = self.postnet(decoder_frames, training=False) - output_frames = decoder_frames + postnet_frames - print(output_frames.shape) - return decoder_frames, output_frames, attentions, stop_tokens - - def build_inference( - self, - ): - # TODO: issue https://github.com/PyCQA/pylint/issues/3613 - input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) # pylint: disable=unexpected-keyword-arg - self(input_ids) diff --git a/TTS/tts/tf/utils/convert_torch_to_tf_utils.py b/TTS/tts/tf/utils/convert_torch_to_tf_utils.py deleted file mode 100644 index 2c615a7d..00000000 --- a/TTS/tts/tf/utils/convert_torch_to_tf_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -import numpy as np -import tensorflow as tf - -# NOTE: linter has a problem with the current TF release -# pylint: disable=no-value-for-parameter -# pylint: disable=unexpected-keyword-arg - - -def tf_create_dummy_inputs(): - """Create dummy inputs for TF Tacotron2 model""" - batch_size = 4 - max_input_length = 32 - max_mel_length = 128 - pad = 1 - n_chars = 24 - input_ids = 
tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32) - input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size]) - input_lengths[-1] = max_input_length - input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32) - mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80]) - mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size]) - mel_lengths[-1] = max_mel_length - mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32) - return input_ids, input_lengths, mel_outputs, mel_lengths - - -def compare_torch_tf(torch_tensor, tf_tensor): - """Compute the average absolute difference b/w torch and tf tensors""" - return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean() - - -def convert_tf_name(tf_name): - """Convert certain patterns in TF layer names to Torch patterns""" - tf_name_tmp = tf_name - tf_name_tmp = tf_name_tmp.replace(":0", "") - tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0") - tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1") - tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh") - tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight") - tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight") - tf_name_tmp = tf_name_tmp.replace("/beta", "/bias") - tf_name_tmp = tf_name_tmp.replace("/", ".") - return tf_name_tmp - - -def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): - """Transfer weigths from torch state_dict to TF variables""" - print(" > Passing weights from Torch to TF ...") - for tf_var in tf_vars: - torch_var_name = var_map_dict[tf_var.name] - print(f" | > {tf_var.name} <-- {torch_var_name}") - # if tuple, it is a bias variable - if not isinstance(torch_var_name, tuple): - torch_layer_name = ".".join(torch_var_name.split(".")[-2:]) - torch_weight = state_dict[torch_var_name] - if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name: - # out_dim, in_dim, filter -> filter, in_dim, out_dim - numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy() - elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name: - numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() - # if variable is for bidirectional lstm and it is a bias vector there - # needs to be pre-defined two matching torch bias vectors - elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name: - bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name] - assert len(bias_vectors) == 2 - numpy_weight = bias_vectors[0] + bias_vectors[1] - elif "rnn" in tf_var.name and "kernel" in tf_var.name: - numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() - elif "rnn" in tf_var.name and "bias" in tf_var.name: - bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key] - assert len(bias_vectors) == 2 - numpy_weight = bias_vectors[0] + bias_vectors[1] - elif "linear_layer" in torch_layer_name and "weight" in torch_var_name: - numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() - else: - numpy_weight = torch_weight.detach().cpu().numpy() - assert np.all( - tf_var.shape == numpy_weight.shape - ), f" [!] 
weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" - tf.keras.backend.set_value(tf_var, numpy_weight) - return tf_vars - - -def load_tf_vars(model_tf, tf_vars): - for tf_var in tf_vars: - model_tf.get_layer(tf_var.name).set_weights(tf_var) - return model_tf diff --git a/TTS/tts/tf/utils/generic_utils.py b/TTS/tts/tf/utils/generic_utils.py deleted file mode 100644 index 681a9457..00000000 --- a/TTS/tts/tf/utils/generic_utils.py +++ /dev/null @@ -1,105 +0,0 @@ -import datetime -import importlib -import pickle - -import fsspec -import numpy as np -import tensorflow as tf - - -def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): - state = { - "model": model.weights, - "optimizer": optimizer, - "step": current_step, - "epoch": epoch, - "date": datetime.date.today().strftime("%B %d, %Y"), - "r": r, - } - state.update(kwargs) - with fsspec.open(output_path, "wb") as f: - pickle.dump(state, f) - - -def load_checkpoint(model, checkpoint_path): - with fsspec.open(checkpoint_path, "rb") as f: - checkpoint = pickle.load(f) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} - tf_vars = model.weights - for tf_var in tf_vars: - layer_name = tf_var.name - try: - chkp_var_value = chkp_var_dict[layer_name] - except KeyError: - class_name = list(chkp_var_dict.keys())[0].split("/")[0] - layer_name = f"{class_name}/{layer_name}" - chkp_var_value = chkp_var_dict[layer_name] - - tf.keras.backend.set_value(tf_var, chkp_var_value) - if "r" in checkpoint.keys(): - model.decoder.set_r(checkpoint["r"]) - return model - - -def sequence_mask(sequence_length, max_len=None): - if max_len is None: - max_len = sequence_length.max() - batch_size = sequence_length.size(0) - seq_range = np.empty([0, max_len], dtype=np.int8) - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - seq_range_expand = seq_range_expand.type_as(sequence_length) - seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) - # B x T_max - return seq_range_expand < seq_length_expand - - -# @tf.custom_gradient -def check_gradient(x, grad_clip): - x_normed = tf.clip_by_norm(x, grad_clip) - grad_norm = tf.norm(grad_clip) - return x_normed, grad_norm - - -def count_parameters(model, c): - try: - return model.count_params() - except RuntimeError: - input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32")) - input_lengths = np.random.randint(100, 129, (8,)) - input_lengths[-1] = 128 - input_lengths = tf.convert_to_tensor(input_lengths.astype("int32")) - mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32") - mel_spec = tf.convert_to_tensor(mel_spec) - speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None - _ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids) - return model.count_params() - - -def setup_model(num_chars, num_speakers, c, enable_tflite=False): - print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower()) - MyModel = getattr(MyModel, c.model) - if c.model.lower() in "tacotron": - raise NotImplementedError(" [!] 
Tacotron model is not ready.") - # tacotron2 - model = MyModel( - num_chars=num_chars, - num_speakers=num_speakers, - r=c.r, - out_channels=c.audio["num_mels"], - decoder_output_dim=c.audio["num_mels"], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - enable_tflite=enable_tflite, - ) - return model diff --git a/TTS/tts/tf/utils/io.py b/TTS/tts/tf/utils/io.py deleted file mode 100644 index de6acff9..00000000 --- a/TTS/tts/tf/utils/io.py +++ /dev/null @@ -1,45 +0,0 @@ -import datetime -import pickle - -import fsspec -import tensorflow as tf - - -def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): - state = { - "model": model.weights, - "optimizer": optimizer, - "step": current_step, - "epoch": epoch, - "date": datetime.date.today().strftime("%B %d, %Y"), - "r": r, - } - state.update(kwargs) - with fsspec.open(output_path, "wb") as f: - pickle.dump(state, f) - - -def load_checkpoint(model, checkpoint_path): - with fsspec.open(checkpoint_path, "rb") as f: - checkpoint = pickle.load(f) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} - tf_vars = model.weights - for tf_var in tf_vars: - layer_name = tf_var.name - try: - chkp_var_value = chkp_var_dict[layer_name] - except KeyError: - class_name = list(chkp_var_dict.keys())[0].split("/")[0] - layer_name = f"{class_name}/{layer_name}" - chkp_var_value = chkp_var_dict[layer_name] - - tf.keras.backend.set_value(tf_var, chkp_var_value) - if "r" in checkpoint.keys(): - model.decoder.set_r(checkpoint["r"]) - return model - - -def load_tflite_model(tflite_path): - tflite_model = tf.lite.Interpreter(model_path=tflite_path) - tflite_model.allocate_tensors() - return tflite_model diff --git a/TTS/tts/tf/utils/tf_utils.py b/TTS/tts/tf/utils/tf_utils.py deleted file mode 100644 index 558936d5..00000000 --- a/TTS/tts/tf/utils/tf_utils.py +++ /dev/null @@ -1,8 +0,0 @@ -import tensorflow as tf - - -def shape_list(x): - """Deal with dynamic shape in tensorflow cleanly.""" - static = x.shape.as_list() - dynamic = tf.shape(x) - return [dynamic[i] if s is None else s for i, s in enumerate(static)] diff --git a/TTS/tts/tf/utils/tflite.py b/TTS/tts/tf/utils/tflite.py deleted file mode 100644 index 2f76aa50..00000000 --- a/TTS/tts/tf/utils/tflite.py +++ /dev/null @@ -1,27 +0,0 @@ -import fsspec -import tensorflow as tf - - -def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True): - """Convert Tensorflow Tacotron2 model to TFLite. 
Save a binary file if output_path is - provided, else return TFLite model.""" - - concrete_function = model.inference_tflite.get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function]) - converter.experimental_new_converter = experimental_converter - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - tflite_model = converter.convert() - print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") - if output_path is not None: - # same model binary if outputpath is provided - with fsspec.open(output_path, "wb") as f: - f.write(tflite_model) - return None - return tflite_model - - -def load_tflite_model(tflite_path): - tflite_model = tf.lite.Interpreter(model_path=tflite_path) - tflite_model.allocate_tensors() - return tflite_model diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 24b747be..b2ea4208 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -1,19 +1,11 @@ -import os from typing import Dict import numpy as np -import pkg_resources import torch from torch import nn from .text import phoneme_to_sequence, text_to_sequence -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - -installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable -if "tensorflow" in installed or "tensorflow-gpu" in installed: - import tensorflow as tf - def text_to_seq(text, CONFIG, custom_symbols=None, language=None): text_cleaner = [CONFIG.text_cleaner] @@ -51,13 +43,6 @@ def numpy_to_torch(np_array, dtype, cuda=False): return tensor -def numpy_to_tf(np_array, dtype): - if np_array is None: - return None - tensor = tf.convert_to_tensor(np_array, dtype=dtype) - return tensor - - def compute_style_mel(style_wav, ap, cuda=False): style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) if cuda: @@ -103,53 +88,6 @@ def run_model_torch( return outputs -def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None): - if CONFIG.gst and style_mel is not None: - raise NotImplementedError(" [!] GST inference not implemented for TF") - if speaker_id is not None: - raise NotImplementedError(" [!] Multi-Speaker not implemented for TF") - # TODO: handle multispeaker case - decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False) - return decoder_output, postnet_output, alignments, stop_tokens - - -def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None): - if CONFIG.gst and style_mel is not None: - raise NotImplementedError(" [!] GST inference not implemented for TfLite") - if speaker_id is not None: - raise NotImplementedError(" [!] 
Multi-Speaker not implemented for TfLite") - # get input and output details - input_details = model.get_input_details() - output_details = model.get_output_details() - # reshape input tensor for the new input shape - model.resize_tensor_input(input_details[0]["index"], inputs.shape) - model.allocate_tensors() - detail = input_details[0] - # input_shape = detail['shape'] - model.set_tensor(detail["index"], inputs) - # run the model - model.invoke() - # collect outputs - decoder_output = model.get_tensor(output_details[0]["index"]) - postnet_output = model.get_tensor(output_details[1]["index"]) - # tflite model only returns feature frames - return decoder_output, postnet_output, None, None - - -def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens): - postnet_output = postnet_output[0].numpy() - decoder_output = decoder_output[0].numpy() - alignment = alignments[0].numpy() - stop_tokens = stop_tokens[0].numpy() - return postnet_output, decoder_output, alignment, stop_tokens - - -def parse_outputs_tflite(postnet_output, decoder_output): - postnet_output = postnet_output[0] - decoder_output = decoder_output[0] - return postnet_output, decoder_output - - def trim_silence(wav, ap): return wav[: ap.find_endpoint(wav)] @@ -213,7 +151,6 @@ def synthesis( d_vector=None, language_id=None, language_name=None, - backend="torch", ): """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to the vocoder model. @@ -254,9 +191,6 @@ def synthesis( language_name (str): Language name corresponding to the language code used by the phonemizer. Defaults to None. - - backend (str): - tf or torch. Defaults to "torch". """ # GST processing style_mel = None @@ -270,44 +204,27 @@ def synthesis( custom_symbols = model.make_symbols(CONFIG) # preprocess the given text text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name) - # pass tensors to backend - if backend == "torch": - if speaker_id is not None: - speaker_id = id_to_torch(speaker_id, cuda=use_cuda) - if d_vector is not None: - d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=use_cuda) - if language_id is not None: - language_id = id_to_torch(language_id, cuda=use_cuda) + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + + if language_id is not None: + language_id = id_to_torch(language_id, cuda=use_cuda) + + if not isinstance(style_mel, dict): + style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) + text_inputs = text_inputs.unsqueeze(0) - if not isinstance(style_mel, dict): - style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) - text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) - text_inputs = text_inputs.unsqueeze(0) - elif backend in ["tf", "tflite"]: - # TODO: handle speaker id for tf model - style_mel = numpy_to_tf(style_mel, tf.float32) - text_inputs = numpy_to_tf(text_inputs, tf.int32) - text_inputs = tf.expand_dims(text_inputs, 0) # synthesize voice - if backend == "torch": - outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id) - model_outputs = outputs["model_outputs"] - model_outputs = model_outputs[0].data.cpu().numpy() - alignments = outputs["alignments"] - elif backend == "tf": - decoder_output, postnet_output, alignments, stop_tokens = run_model_tf( 
- model, text_inputs, CONFIG, speaker_id, style_mel - ) - model_outputs, decoder_output, alignments, stop_tokens = parse_outputs_tf( - postnet_output, decoder_output, alignments, stop_tokens - ) - elif backend == "tflite": - decoder_output, postnet_output, alignments, stop_tokens = run_model_tflite( - model, text_inputs, CONFIG, speaker_id, style_mel - ) - model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output) + outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id) + model_outputs = outputs["model_outputs"] + model_outputs = model_outputs[0].data.cpu().numpy() + alignments = outputs["alignments"] + # convert outputs to numpy # plot results wav = None diff --git a/TTS/vocoder/tf/layers/melgan.py b/TTS/vocoder/tf/layers/melgan.py deleted file mode 100644 index 90bce6f1..00000000 --- a/TTS/vocoder/tf/layers/melgan.py +++ /dev/null @@ -1,54 +0,0 @@ -import tensorflow as tf - - -class ReflectionPad1d(tf.keras.layers.Layer): - def __init__(self, padding): - super().__init__() - self.padding = padding - - def call(self, x): - return tf.pad(x, [[0, 0], [self.padding, self.padding], [0, 0], [0, 0]], "REFLECT") - - -class ResidualStack(tf.keras.layers.Layer): - def __init__(self, channels, num_res_blocks, kernel_size, name): - super().__init__(name=name) - - assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd." - base_padding = (kernel_size - 1) // 2 - - self.blocks = [] - num_layers = 2 - for idx in range(num_res_blocks): - layer_kernel_size = kernel_size - layer_dilation = layer_kernel_size ** idx - layer_padding = base_padding * layer_dilation - block = [ - tf.keras.layers.LeakyReLU(0.2), - ReflectionPad1d(layer_padding), - tf.keras.layers.Conv2D( - filters=channels, - kernel_size=(kernel_size, 1), - dilation_rate=(layer_dilation, 1), - use_bias=True, - padding="valid", - name=f"blocks.{idx}.{num_layers}", - ), - tf.keras.layers.LeakyReLU(0.2), - tf.keras.layers.Conv2D( - filters=channels, kernel_size=(1, 1), use_bias=True, name=f"blocks.{idx}.{num_layers + 2}" - ), - ] - self.blocks.append(block) - self.shortcuts = [ - tf.keras.layers.Conv2D(channels, kernel_size=1, use_bias=True, name=f"shortcuts.{i}") - for i in range(num_res_blocks) - ] - - def call(self, x): - for block, shortcut in zip(self.blocks, self.shortcuts): - res = shortcut(x) - for layer in block: - x = layer(x) - x += res - return x diff --git a/TTS/vocoder/tf/layers/pqmf.py b/TTS/vocoder/tf/layers/pqmf.py deleted file mode 100644 index 042f2f08..00000000 --- a/TTS/vocoder/tf/layers/pqmf.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np -import tensorflow as tf -from scipy import signal as sig - - -class PQMF(tf.keras.layers.Layer): - def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): - super().__init__() - # define filter coefficient - self.N = N - self.taps = taps - self.cutoff = cutoff - self.beta = beta - - QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) - H = np.zeros((N, len(QMF))) - G = np.zeros((N, len(QMF))) - for k in range(N): - constant_factor = (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) - phase = (-1) ** k * np.pi / 4 - H[k] = 2 * QMF * np.cos(constant_factor + phase) - - G[k] = 2 * QMF * np.cos(constant_factor - phase) - - # [N, 1, taps + 1] == [filter_width, in_channels, out_channels] - self.H = np.transpose(H[:, None, :], (2, 1, 0)).astype("float32") - self.G = np.transpose(G[None, :, :], (2, 1, 0)).astype("float32") - - # filter for downsampling & upsampling 
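
The filter defined next is an identity polyphase kernel: `conv1d` with it at stride `N` decimates each band, and `conv1d_transpose` with it (scaled by `N`) zero-stuffs the samples back. Together with the cosine-modulated banks `H` and `G` built above, this forms a pseudo-QMF analysis/synthesis pair. A minimal NumPy round-trip sketch — an illustration assuming only NumPy/SciPy, not repo code, mirroring the defaults `N=4, taps=62, cutoff=0.15, beta=9.0`; reconstruction is near-perfect, not exact:

```python
import numpy as np
from scipy.signal import firwin

N, taps, cutoff, beta = 4, 62, 0.15, 9.0
qmf = firwin(taps + 1, cutoff, window=("kaiser", beta))  # Kaiser-windowed lowpass prototype
n = np.arange(taps + 1)
k = np.arange(N)[:, None]
arg = (2 * k + 1) * (np.pi / (2 * N)) * (n - (taps - 1) / 2)
H = 2 * qmf * np.cos(arg + (-1) ** k * np.pi / 4)  # analysis bank
G = 2 * qmf * np.cos(arg - (-1) ** k * np.pi / 4)  # synthesis bank

x = np.random.randn(4096)
xp = np.pad(x, taps // 2)
# analysis: correlate (tf.nn.conv1d semantics, hence the kernel flip) and decimate by N
bands = np.stack([np.convolve(xp, h[::-1], mode="valid")[::N] for h in H])
# synthesis: zero-stuff by N scaled by N (what conv1d_transpose with updown_filter * N does),
# then filter each band and sum
up = np.zeros((N, len(x)))
up[:, ::N] = bands * N
y = sum(np.convolve(np.pad(u, taps // 2), g[::-1], mode="valid") for u, g in zip(up, G))
print("max |x - y|:", np.max(np.abs(x - y)))  # small but nonzero: pseudo-QMF is near-perfect
```
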
- updown_filter = np.zeros((N, N, N), dtype=np.float32) - for k in range(N): - updown_filter[0, k, k] = 1.0 - self.updown_filter = updown_filter.astype(np.float32) - - def analysis(self, x): - """ - x : :math:`[B, 1, T]` - """ - x = tf.transpose(x, perm=[0, 2, 1]) - x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) - x = tf.nn.conv1d(x, self.H, stride=1, padding="VALID") - x = tf.nn.conv1d(x, self.updown_filter, stride=self.N, padding="VALID") - x = tf.transpose(x, perm=[0, 2, 1]) - return x - - def synthesis(self, x): - """ - x : B x D x T - """ - x = tf.transpose(x, perm=[0, 2, 1]) - x = tf.nn.conv1d_transpose( - x, - self.updown_filter * self.N, - strides=self.N, - output_shape=(tf.shape(x)[0], tf.shape(x)[1] * self.N, self.N), - ) - x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) - x = tf.nn.conv1d(x, self.G, stride=1, padding="VALID") - x = tf.transpose(x, perm=[0, 2, 1]) - return x diff --git a/TTS/vocoder/tf/models/melgan_generator.py b/TTS/vocoder/tf/models/melgan_generator.py deleted file mode 100644 index 09ee9530..00000000 --- a/TTS/vocoder/tf/models/melgan_generator.py +++ /dev/null @@ -1,133 +0,0 @@ -import logging -import os - -import tensorflow as tf - -from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # FATAL -logging.getLogger("tensorflow").setLevel(logging.FATAL) - -from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack - - -# pylint: disable=too-many-ancestors -# pylint: disable=abstract-method -class MelganGenerator(tf.keras.models.Model): - """Melgan Generator TF implementation dedicated for inference with no - weight norm""" - - def __init__( - self, - in_channels=80, - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=(8, 8, 2, 2), - res_kernel=3, - num_res_blocks=3, - ): - super().__init__() - - self.in_channels = in_channels - - # assert model parameters - assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." 
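
The odd-kernel assert exists because this port keeps sequence length fixed with the "reflect-pad then `valid` convolution" pattern (`ReflectionPad1d` with `base_padding`, computed just below). The arithmetic only balances when the kernel width is odd; a trivial length check illustrates it:

```python
k = 7                    # proj_kernel (must be odd)
p = (k - 1) // 2         # base_padding, applied on both sides
T = 100                  # input time steps
# a "valid" convolution of width k removes k - 1 steps; the padding adds exactly 2p = k - 1
assert (T + 2 * p) - (k - 1) == T
```
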
- - # setup additional model parameters - base_padding = (proj_kernel - 1) // 2 - act_slope = 0.2 - self.inference_padding = 2 - - # initial layer - self.initial_layer = [ - ReflectionPad1d(base_padding), - tf.keras.layers.Conv2D( - filters=base_channels, kernel_size=(proj_kernel, 1), strides=1, padding="valid", use_bias=True, name="1" - ), - ] - num_layers = 3 # count number of layers for layer naming - - # upsampling layers and residual stacks - self.upsample_layers = [] - for idx, upsample_factor in enumerate(upsample_factors): - layer_out_channels = base_channels // (2 ** (idx + 1)) - layer_filter_size = upsample_factor * 2 - layer_stride = upsample_factor - # layer_output_padding = upsample_factor % 2 - self.upsample_layers += [ - tf.keras.layers.LeakyReLU(act_slope), - tf.keras.layers.Conv2DTranspose( - filters=layer_out_channels, - kernel_size=(layer_filter_size, 1), - strides=(layer_stride, 1), - padding="same", - # output_padding=layer_output_padding, - use_bias=True, - name=f"{num_layers}", - ), - ResidualStack( - channels=layer_out_channels, - num_res_blocks=num_res_blocks, - kernel_size=res_kernel, - name=f"layers.{num_layers + 1}", - ), - ] - num_layers += num_res_blocks - 1 - - self.upsample_layers += [tf.keras.layers.LeakyReLU(act_slope)] - - # final layer - self.final_layers = [ - ReflectionPad1d(base_padding), - tf.keras.layers.Conv2D( - filters=out_channels, kernel_size=(proj_kernel, 1), use_bias=True, name=f"layers.{num_layers + 1}" - ), - tf.keras.layers.Activation("tanh"), - ] - - # self.model_layers = tf.keras.models.Sequential(self.initial_layer + self.upsample_layers + self.final_layers, name="layers") - self.model_layers = self.initial_layer + self.upsample_layers + self.final_layers - - @tf.function(experimental_relax_shapes=True) - def call(self, c, training=False): - """ - c : :math:`[B, C, T]` - """ - if training: - raise NotImplementedError() - return self.inference(c) - - def inference(self, c): - c = tf.transpose(c, perm=[0, 2, 1]) - c = tf.expand_dims(c, 2) - # FIXME: TF had no replicate padding as in Torch - # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT") - o = c - for layer in self.model_layers: - o = layer(o) - # o = self.model_layers(c) - o = tf.transpose(o, perm=[0, 3, 2, 1]) - return o[:, :, 0, :] - - def build_inference(self): - x = tf.random.uniform((1, self.in_channels, 4), dtype=tf.float32) - self(x, training=False) - - @tf.function( - experimental_relax_shapes=True, - input_signature=[ - tf.TensorSpec([1, None, None], dtype=tf.float32), - ], - ) - def inference_tflite(self, c): - c = tf.transpose(c, perm=[0, 2, 1]) - c = tf.expand_dims(c, 2) - # FIXME: TF had no replicate padding as in Torch - # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT") - o = c - for layer in self.model_layers: - o = layer(o) - # o = self.model_layers(c) - o = tf.transpose(o, perm=[0, 3, 2, 1]) - return o[:, :, 0, :] diff --git a/TTS/vocoder/tf/models/multiband_melgan_generator.py b/TTS/vocoder/tf/models/multiband_melgan_generator.py deleted file mode 100644 index 24d899b2..00000000 --- a/TTS/vocoder/tf/models/multiband_melgan_generator.py +++ /dev/null @@ -1,65 +0,0 @@ -import tensorflow as tf - -from TTS.vocoder.tf.layers.pqmf import PQMF -from TTS.vocoder.tf.models.melgan_generator import MelganGenerator - - -# pylint: disable=too-many-ancestors -# pylint: disable=abstract-method -class MultibandMelganGenerator(MelganGenerator): - def __init__( - self, - 
in_channels=80, - out_channels=4, - proj_kernel=7, - base_channels=384, - upsample_factors=(2, 8, 2, 2), - res_kernel=3, - num_res_blocks=3, - ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks, - ) - self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) - - def pqmf_analysis(self, x): - return self.pqmf_layer.analysis(x) - - def pqmf_synthesis(self, x): - return self.pqmf_layer.synthesis(x) - - def inference(self, c): - c = tf.transpose(c, perm=[0, 2, 1]) - c = tf.expand_dims(c, 2) - # FIXME: TF had no replicate padding as in Torch - # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT") - o = c - for layer in self.model_layers: - o = layer(o) - o = tf.transpose(o, perm=[0, 3, 2, 1]) - o = self.pqmf_layer.synthesis(o[:, :, 0, :]) - return o - - @tf.function( - experimental_relax_shapes=True, - input_signature=[ - tf.TensorSpec([1, 80, None], dtype=tf.float32), - ], - ) - def inference_tflite(self, c): - c = tf.transpose(c, perm=[0, 2, 1]) - c = tf.expand_dims(c, 2) - # FIXME: TF had no replicate padding as in Torch - # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT") - o = c - for layer in self.model_layers: - o = layer(o) - o = tf.transpose(o, perm=[0, 3, 2, 1]) - o = self.pqmf_layer.synthesis(o[:, :, 0, :]) - return o diff --git a/TTS/vocoder/tf/utils/__init__.py b/TTS/vocoder/tf/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py b/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py deleted file mode 100644 index 453d8b78..00000000 --- a/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np -import tensorflow as tf - - -def compare_torch_tf(torch_tensor, tf_tensor): - """Compute the average absolute difference b/w torch and tf tensors""" - return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean() - - -def convert_tf_name(tf_name): - """Convert certain patterns in TF layer names to Torch patterns""" - tf_name_tmp = tf_name - tf_name_tmp = tf_name_tmp.replace(":0", "") - tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0") - tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1") - tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh") - tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight") - tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight") - tf_name_tmp = tf_name_tmp.replace("/beta", "/bias") - tf_name_tmp = tf_name_tmp.replace("/", ".") - return tf_name_tmp - - -def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): - """Transfer weigths from torch state_dict to TF variables""" - print(" > Passing weights from Torch to TF ...") - for tf_var in tf_vars: - torch_var_name = var_map_dict[tf_var.name] - print(f" | > {tf_var.name} <-- {torch_var_name}") - # if tuple, it is a bias variable - if "kernel" in tf_var.name: - torch_weight = state_dict[torch_var_name] - numpy_weight = torch_weight.permute([2, 1, 0]).numpy()[:, None, :, :] - if "bias" in tf_var.name: - torch_weight = state_dict[torch_var_name] - numpy_weight = torch_weight - assert np.all( - tf_var.shape == numpy_weight.shape - ), f" [!] 
weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" - tf.keras.backend.set_value(tf_var, numpy_weight) - return tf_vars - - -def load_tf_vars(model_tf, tf_vars): - for tf_var in tf_vars: - model_tf.get_layer(tf_var.name).set_weights(tf_var) - return model_tf diff --git a/TTS/vocoder/tf/utils/generic_utils.py b/TTS/vocoder/tf/utils/generic_utils.py deleted file mode 100644 index 94364ab4..00000000 --- a/TTS/vocoder/tf/utils/generic_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -import importlib -import re - - -def to_camel(text): - text = text.capitalize() - return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) - - -def setup_generator(c): - print(" > Generator Model: {}".format(c.generator_model)) - MyModel = importlib.import_module("TTS.vocoder.tf.models." + c.generator_model.lower()) - MyModel = getattr(MyModel, to_camel(c.generator_model)) - if c.generator_model in "melgan_generator": - model = MyModel( - in_channels=c.audio["num_mels"], - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=c.generator_model_params["upsample_factors"], - res_kernel=3, - num_res_blocks=c.generator_model_params["num_res_blocks"], - ) - if c.generator_model in "melgan_fb_generator": - pass - if c.generator_model in "multiband_melgan_generator": - model = MyModel( - in_channels=c.audio["num_mels"], - out_channels=4, - proj_kernel=7, - base_channels=384, - upsample_factors=c.generator_model_params["upsample_factors"], - res_kernel=3, - num_res_blocks=c.generator_model_params["num_res_blocks"], - ) - return model diff --git a/TTS/vocoder/tf/utils/io.py b/TTS/vocoder/tf/utils/io.py deleted file mode 100644 index 3de8adab..00000000 --- a/TTS/vocoder/tf/utils/io.py +++ /dev/null @@ -1,31 +0,0 @@ -import datetime -import pickle - -import fsspec -import tensorflow as tf - - -def save_checkpoint(model, current_step, epoch, output_path, **kwargs): - """Save TF Vocoder model""" - state = { - "model": model.weights, - "step": current_step, - "epoch": epoch, - "date": datetime.date.today().strftime("%B %d, %Y"), - } - state.update(kwargs) - with fsspec.open(output_path, "wb") as f: - pickle.dump(state, f) - - -def load_checkpoint(model, checkpoint_path): - """Load TF Vocoder model""" - with fsspec.open(checkpoint_path, "rb") as f: - checkpoint = pickle.load(f) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} - tf_vars = model.weights - for tf_var in tf_vars: - layer_name = tf_var.name - chkp_var_value = chkp_var_dict[layer_name] - tf.keras.backend.set_value(tf_var, chkp_var_value) - return model diff --git a/TTS/vocoder/tf/utils/tflite.py b/TTS/vocoder/tf/utils/tflite.py deleted file mode 100644 index 876739fd..00000000 --- a/TTS/vocoder/tf/utils/tflite.py +++ /dev/null @@ -1,27 +0,0 @@ -import fsspec -import tensorflow as tf - - -def convert_melgan_to_tflite(model, output_path=None, experimental_converter=True): - """Convert Tensorflow MelGAN model to TFLite. 
Save a binary file if output_path is - provided, else return TFLite model.""" - - concrete_function = model.inference_tflite.get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function]) - converter.experimental_new_converter = experimental_converter - converter.optimizations = [] - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - tflite_model = converter.convert() - print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") - if output_path is not None: - # same model binary if outputpath is provided - with fsspec.open(output_path, "wb") as f: - f.write(tflite_model) - return None - return tflite_model - - -def load_tflite_model(tflite_path): - tflite_model = tf.lite.Interpreter(model_path=tflite_path) - tflite_model.allocate_tensors() - return tflite_model diff --git a/docs/source/converting_torch_to_tf.md b/docs/source/converting_torch_to_tf.md deleted file mode 100644 index 20a0be6b..00000000 --- a/docs/source/converting_torch_to_tf.md +++ /dev/null @@ -1,21 +0,0 @@ -# Converting Torch to TF 2 - -Currently, 🐸TTS supports the vanilla Tacotron2 and MelGAN models in TF 2.It does not support advanced attention methods and other small tricks used by the Torch models. You can convert any Torch model trained after v0.0.2. - -You can also export TF 2 models to TFLite for even faster inference. - -## How to convert from Torch to TF 2.0 -Make sure you installed Tensorflow v2.2. It is not installed by default by :frog: TTS. - -All the TF related code stays under ```tf``` folder. - -To convert a **compatible** Torch model, run the following command with the right arguments: - -```bash -python TTS/bin/convert_tacotron2_torch_to_tf.py\ - --torch_model_path /path/to/torch/model.pth.tar \ - --config_path /path/to/model/config.json\ - --output_path /path/to/output/tf/model -``` - -This will create a TF model file. Notice that our model format is not compatible with the official TF checkpoints. We created our custom format to match Torch checkpoints we use. Therefore, use the ```load_checkpoint``` and ```save_checkpoint``` functions provided under ```TTS.tf.generic_utils```. diff --git a/docs/source/index.md b/docs/source/index.md index 756cea8e..9dc5bfce 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -27,7 +27,6 @@ formatting_your_dataset what_makes_a_good_dataset tts_datasets - converting_torch_to_tf .. toctree:: :maxdepth: 2 diff --git a/docs/source/installation.md b/docs/source/installation.md index 6532ee8e..0122271d 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -12,12 +12,6 @@ You can install from PyPI as follows: pip install TTS # from PyPI ``` -By default, this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra. 
- -```bash -pip install TTS[tf] -``` - Or install from Github: ```bash diff --git a/notebooks/Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb b/notebooks/Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb deleted file mode 100644 index 8a25132c..00000000 --- a/notebooks/Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb +++ /dev/null @@ -1,425 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "6LWsNd3_M3MP" - }, - "source": [ - "# Converting Pytorch models to Tensorflow and TFLite by CoquiTTS" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FAqrSIWgLyP0" - }, - "source": [ - "This is a tutorial demonstrating Coqui TTS capabilities to convert \n", - "trained PyTorch models to Tensorflow and Tflite.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MBJjGYnoEo4v" - }, - "source": [ - "# Installation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Ku-dA4DKoeXk" - }, - "source": [ - "### Download TF Models and configs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 162 - }, - "colab_type": "code", - "id": "jGIgnWhGsxU1", - "outputId": "b461952f-8507-4dd2-af06-4e6b8692765d", - "tags": [] - }, - "outputs": [], - "source": [ - "!gdown --id 1dntzjWFg7ufWaTaFy80nRz-Tu02xWZos -O data/tts_model.pth.tar\n", - "!gdown --id 18CQ6G6tBEOfvCHlPqP8EBI4xWbrr9dBc -O data/config.json" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 235 - }, - "colab_type": "code", - "id": "4dnpE0-kvTsu", - "outputId": "f67c3138-bda0-4b3e-ffcc-647f9feec23e", - "tags": [] - }, - "outputs": [], - "source": [ - "!gdown --id 1Ty5DZdOc0F7OTGj9oJThYbL5iVu_2G0K -O data/vocoder_model.pth.tar\n", - "!gdown --id 1Rd0R_nRCrbjEdpOwq6XwZAktvugiBvmu -O data/config_vocoder.json\n", - "!gdown --id 11oY3Tv0kQtxK_JPgxrfesa99maVXHNxU -O data/scale_stats.npy" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3IGvvCRMEwqn" - }, - "source": [ - "# Model Conversion PyTorch -> TF -> TFLite" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "tLhz8SAf8Pgp" - }, - "source": [ - "## Converting PyTorch to Tensorflow\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "colab_type": "code", - "id": "Xsrvr_WQ8Ib5", - "outputId": "dae96616-e5f7-41b6-cdb9-5026cfcd3214", - "tags": [] - }, - "outputs": [], - "source": [ - "# convert TTS model to Tensorflow\n", - "!python ../TTS/bin/convert_tacotron2_torch_to_tf.py --config_path data/config.json --torch_model_path data/tts_model.pth.tar --output_path data/tts_model_tf.pkl" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "colab_type": "code", - "id": "VJ4NA5If9ljv", - "outputId": "1520dca8-1db8-4e07-bc0c-b1d5941c775e", - "tags": [] - }, - "outputs": [], - "source": [ - "# convert Vocoder model to Tensorflow\n", - "!python ../TTS/bin/convert_melgan_torch_to_tf.py --config_path data/config_vocoder.json --torch_model_path data/vocoder_model.pth.tar --output_path data/vocoder_model_tf.pkl" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": 
"text", - "id": "7d5vTkBZ-BYQ" - }, - "source": [ - "## Converting Tensorflow to TFLite" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 927 - }, - "colab_type": "code", - "id": "33hTfpuU99cg", - "outputId": "8a0e5be1-23a2-4128-ee37-8232adcb8ff0", - "tags": [] - }, - "outputs": [], - "source": [ - "# convert TTS model to TFLite\n", - "!python ../TTS/bin/convert_tacotron2_tflite.py --config_path data/config.json --tf_model data/tts_model_tf.pkl --output_path data/tts_model.tflite" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 364 - }, - "colab_type": "code", - "id": "e00Hm75Y-wZ2", - "outputId": "42381b05-3c9d-44f0-dac7-d81efd95eadf", - "tags": [] - }, - "outputs": [], - "source": [ - "# convert Vocoder model to TFLite\n", - "!python ../TTS/bin/convert_melgan_tflite.py --config_path data/config_vocoder.json --tf_model data/vocoder_model_tf.pkl --output_path data/vocoder_model.tflite" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Zlgi8fPdpRF0" - }, - "source": [ - "# Run Inference with TFLite " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "f-Yc42nQZG5A" - }, - "outputs": [], - "source": [ - "def run_vocoder(mel_spec):\n", - " vocoder_inputs = mel_spec[None, :, :]\n", - " # get input and output details\n", - " input_details = vocoder_model.get_input_details()\n", - " # reshape input tensor for the new input shape\n", - " vocoder_model.resize_tensor_input(input_details[0]['index'], vocoder_inputs.shape)\n", - " vocoder_model.allocate_tensors()\n", - " detail = input_details[0]\n", - " vocoder_model.set_tensor(detail['index'], vocoder_inputs)\n", - " # run the model\n", - " vocoder_model.invoke()\n", - " # collect outputs\n", - " output_details = vocoder_model.get_output_details()\n", - " waveform = vocoder_model.get_tensor(output_details[0]['index'])\n", - " return waveform \n", - "\n", - "\n", - "def tts(model, text, CONFIG, p):\n", - " t_1 = time.time()\n", - " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens, inputs = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, style_wav=None,\n", - " truncated=False, enable_eos_bos_chars=CONFIG.enable_eos_bos_chars,\n", - " backend='tflite')\n", - " waveform = run_vocoder(mel_postnet_spec.T)\n", - " waveform = waveform[0, 0]\n", - " rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n", - " tps = (time.time() - t_1) / len(waveform)\n", - " print(waveform.shape)\n", - " print(\" > Run-time: {}\".format(time.time() - t_1))\n", - " print(\" > Real-time factor: {}\".format(rtf))\n", - " print(\" > Time per step: {}\".format(tps))\n", - " IPython.display.display(IPython.display.Audio(waveform, rate=CONFIG.audio['sample_rate'])) \n", - " return alignment, mel_postnet_spec, stop_tokens, waveform" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "ZksegYQepkFg" - }, - "source": [ - "### Load TF Models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "oVa0kOamprgj" - }, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "import time\n", - "import IPython\n", - "\n", - "from TTS.tts.tf.utils.tflite import load_tflite_model\n", - "from TTS.tts.tf.utils.io import load_checkpoint\n", - 
"from TTS.utils.io import load_config\n", - "from TTS.tts.utils.text.symbols import symbols, phonemes\n", - "from TTS.utils.audio import AudioProcessor\n", - "from TTS.tts.utils.synthesis import synthesis" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "EY-sHVO8IFSH" - }, - "outputs": [], - "source": [ - "# runtime settings\n", - "use_cuda = False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "_1aIUp2FpxOQ" - }, - "outputs": [], - "source": [ - "# model paths\n", - "TTS_MODEL = \"data/tts_model.tflite\"\n", - "TTS_CONFIG = \"data/config.json\"\n", - "VOCODER_MODEL = \"data/vocoder_model.tflite\"\n", - "VOCODER_CONFIG = \"data/config_vocoder.json\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "CpgmdBVQplbv" - }, - "outputs": [], - "source": [ - "# load configs\n", - "TTS_CONFIG = load_config(TTS_CONFIG)\n", - "VOCODER_CONFIG = load_config(VOCODER_CONFIG)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 471 - }, - "colab_type": "code", - "id": "zmrQxiozIUVE", - "outputId": "21cda136-de87-4d55-fd46-7d5306103d90", - "tags": [] - }, - "outputs": [], - "source": [ - "# load the audio processor\n", - "TTS_CONFIG.audio['stats_path'] = 'data/scale_stats.npy'\n", - "ap = AudioProcessor(**TTS_CONFIG.audio) " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "8fLoI4ipqMeS" - }, - "outputs": [], - "source": [ - "# LOAD TTS MODEL\n", - "# multi speaker \n", - "speaker_id = None\n", - "speakers = []\n", - "\n", - "# load the models\n", - "model = load_tflite_model(TTS_MODEL)\n", - "vocoder_model = load_tflite_model(VOCODER_MODEL)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Ws_YkPKsLgo-" - }, - "source": [ - "## Run Sample Sentence" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 134 - }, - "colab_type": "code", - "id": "FuWxZ9Ey5Puj", - "outputId": "535c2df1-c27c-458b-e14b-41a977635aa1", - "tags": [] - }, - "outputs": [], - "source": [ - "sentence = \"Bill got in the habit of asking himself “Is that thought true?” and if he wasn’t absolutely certain it was, he just let it go.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, TTS_CONFIG, ap)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/requirements.tf.txt b/requirements.tf.txt deleted file mode 100644 index 8e256a90..00000000 --- a/requirements.tf.txt +++ /dev/null @@ -1 +0,0 @@ -tensorflow==2.5.0 diff --git a/setup.py b/setup.py index 95f0841b..1d4dbf1c 100644 --- a/setup.py 
+++ b/setup.py @@ -65,9 +65,7 @@ with open(os.path.join(cwd, "requirements.notebooks.txt"), "r") as f: requirements_notebooks = f.readlines() with open(os.path.join(cwd, "requirements.dev.txt"), "r") as f: requirements_dev = f.readlines() -with open(os.path.join(cwd, "requirements.tf.txt"), "r") as f: - requirements_tf = f.readlines() -requirements_all = requirements_dev + requirements_notebooks + requirements_tf +requirements_all = requirements_dev + requirements_notebooks with open("README.md", "r", encoding="utf-8") as readme_file: README = readme_file.read() @@ -116,7 +114,6 @@ setup( "all": requirements_all, "dev": requirements_dev, "notebooks": requirements_notebooks, - "tf": requirements_tf, }, python_requires=">=3.6.0, <3.10", entry_points={"console_scripts": ["tts=TTS.bin.synthesize:main", "tts-server = TTS.server.server:main"]}, diff --git a/tests/tts_tests/test_tacotron2_tf_model.py b/tests/tts_tests/test_tacotron2_tf_model.py deleted file mode 100644 index fb1efcde..00000000 --- a/tests/tts_tests/test_tacotron2_tf_model.py +++ /dev/null @@ -1,156 +0,0 @@ -import os -import unittest - -import numpy as np -import tensorflow as tf -import torch - -from TTS.tts.configs.tacotron2_config import Tacotron2Config -from TTS.tts.tf.models.tacotron2 import Tacotron2 -from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model - -tf.get_logger().setLevel("INFO") - - -# pylint: disable=unused-variable - -torch.manual_seed(1) -use_cuda = torch.cuda.is_available() -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - -c = Tacotron2Config() - - -class TacotronTFTrainTest(unittest.TestCase): - @staticmethod - def generate_dummy_inputs(): - chars_seq = torch.randint(0, 24, (8, 128)).long().to(device) - chars_seq_lengths = torch.randint(100, 128, (8,)).long().to(device) - chars_seq_lengths = torch.sort(chars_seq_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) - stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8,)).long().to(device) - - chars_seq = tf.convert_to_tensor(chars_seq.cpu().numpy()) - chars_seq_lengths = tf.convert_to_tensor(chars_seq_lengths.cpu().numpy()) - mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy()) - return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths, stop_targets, speaker_ids - - @unittest.skipIf(use_cuda, " [!] 
Skip Test: TfLite conversion does not work on GPU.") - def test_train_step(self): - """test forward pass""" - ( - chars_seq, - chars_seq_lengths, - mel_spec, - mel_postnet_spec, - mel_lengths, - stop_targets, - speaker_ids, - ) = self.generate_dummy_inputs() - - for idx in mel_lengths: - stop_targets[:, int(idx.item()) :, 0] = 1.0 - - stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() - - model = Tacotron2(num_chars=24, r=c.r, num_speakers=5) - # training pass - output = model(chars_seq, chars_seq_lengths, mel_spec, training=True) - - # check model output shapes - assert np.all(output[0].shape == mel_spec.shape) - assert np.all(output[1].shape == mel_spec.shape) - assert output[2].shape[2] == chars_seq.shape[1] - assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r) - assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r) - - # inference pass - output = model(chars_seq, training=False) - - @unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.") - def test_forward_attention( - self, - ): - ( - chars_seq, - chars_seq_lengths, - mel_spec, - mel_postnet_spec, - mel_lengths, - stop_targets, - speaker_ids, - ) = self.generate_dummy_inputs() - - for idx in mel_lengths: - stop_targets[:, int(idx.item()) :, 0] = 1.0 - - stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() - - model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, forward_attn=True) - # training pass - output = model(chars_seq, chars_seq_lengths, mel_spec, training=True) - - # check model output shapes - assert np.all(output[0].shape == mel_spec.shape) - assert np.all(output[1].shape == mel_spec.shape) - assert output[2].shape[2] == chars_seq.shape[1] - assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r) - assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r) - - # inference pass - output = model(chars_seq, training=False) - - @unittest.skipIf(use_cuda, " [!] 
Skip Test: TfLite conversion does not work on GPU.") - def test_tflite_conversion( - self, - ): # pylint:disable=no-self-use - model = Tacotron2( - num_chars=24, - num_speakers=0, - r=3, - out_channels=80, - decoder_output_dim=80, - attn_type="original", - attn_win=False, - attn_norm="sigmoid", - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=0, - separate_stopnet=True, - bidirectional_decoder=False, - enable_tflite=True, - ) - model.build_inference() - convert_tacotron2_to_tflite(model, output_path="test_tacotron2.tflite", experimental_converter=True) - # init tflite model - tflite_model = load_tflite_model("test_tacotron2.tflite") - # fake input - inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) # pylint:disable=unexpected-keyword-arg - # run inference - # get input and output details - input_details = tflite_model.get_input_details() - output_details = tflite_model.get_output_details() - # reshape input tensor for the new input shape - tflite_model.resize_tensor_input( - input_details[0]["index"], inputs.shape - ) # pylint:disable=unexpected-keyword-arg - tflite_model.allocate_tensors() - detail = input_details[0] - input_shape = detail["shape"] - tflite_model.set_tensor(detail["index"], inputs) - # run the tflite_model - tflite_model.invoke() - # collect outputs - decoder_output = tflite_model.get_tensor(output_details[0]["index"]) - postnet_output = tflite_model.get_tensor(output_details[1]["index"]) - # remove tflite binary - os.remove("test_tacotron2.tflite") diff --git a/tests/vocoder_tests/test_vocoder_tf_melgan_generator.py b/tests/vocoder_tests/test_vocoder_tf_melgan_generator.py deleted file mode 100644 index 225ceaf5..00000000 --- a/tests/vocoder_tests/test_vocoder_tf_melgan_generator.py +++ /dev/null @@ -1,19 +0,0 @@ -import unittest - -import numpy as np -import tensorflow as tf -import torch - -from TTS.vocoder.tf.models.melgan_generator import MelganGenerator - -use_cuda = torch.cuda.is_available() - - -@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.") -def test_melgan_generator(): - hop_length = 256 - model = MelganGenerator() - # pylint: disable=no-value-for-parameter - dummy_input = tf.random.uniform((4, 80, 64)) - output = model(dummy_input, training=False) - assert np.all(output.shape == (4, 1, 64 * hop_length)), output.shape diff --git a/tests/vocoder_tests/test_vocoder_tf_pqmf.py b/tests/vocoder_tests/test_vocoder_tf_pqmf.py deleted file mode 100644 index 6acb20d9..00000000 --- a/tests/vocoder_tests/test_vocoder_tf_pqmf.py +++ /dev/null @@ -1,31 +0,0 @@ -import os -import unittest - -import soundfile as sf -import tensorflow as tf -import torch -from librosa.core import load - -from tests import get_tests_input_path, get_tests_output_path, get_tests_path -from TTS.vocoder.tf.layers.pqmf import PQMF - -TESTS_PATH = get_tests_path() -WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") -use_cuda = torch.cuda.is_available() - - -@unittest.skipIf(use_cuda, " [!] 
Skip Test: Loosy TF support.") -def test_pqmf(): - w, sr = load(WAV_FILE) - - layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) - w, sr = load(WAV_FILE) - w2 = tf.convert_to_tensor(w[None, None, :]) - b2 = layer.analysis(w2) - w2_ = layer.synthesis(b2) - w2_ = w2.numpy() - - print(w2_.max()) - print(w2_.min()) - print(w2_.mean()) - sf.write(os.path.join(get_tests_output_path(), "tf_pqmf_output.wav"), w2_.flatten(), sr) From 5e3f499a69555eb1aaffefed79f0c132ef57d59b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 11 Feb 2022 13:27:59 +0100 Subject: [PATCH 004/214] Fix #1187 (#1227) --- TTS/vocoder/configs/parallel_wavegan_config.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/TTS/vocoder/configs/parallel_wavegan_config.py b/TTS/vocoder/configs/parallel_wavegan_config.py index a89b1f3f..f536ba98 100644 --- a/TTS/vocoder/configs/parallel_wavegan_config.py +++ b/TTS/vocoder/configs/parallel_wavegan_config.py @@ -70,11 +70,11 @@ class ParallelWaveganConfig(BaseGANVocoderConfig): lr_scheduler_gen (torch.optim.Scheduler): Learning rate scheduler for the generator. Defaults to `ExponentialLR`. lr_scheduler_gen_params (dict): - Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. lr_scheduler_disc (torch.optim.Scheduler): Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`. lr_scheduler_dict_params (dict): - Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. """ model: str = "parallel_wavegan" @@ -124,7 +124,8 @@ class ParallelWaveganConfig(BaseGANVocoderConfig): lr_disc: float = 0.0002 # Initial learning rate. 
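
For reference, a standalone sketch (assuming only PyTorch, not repo code) of what the new scheduler defaults below amount to: the learning rate starts at 2e-4 and is halved every 200k optimizer steps, with the scheduler stepped once per iteration rather than per epoch since `scheduler_after_epoch` is disabled:

```python
import torch

# betas/weight_decay taken from this config's optimizer_params
param = torch.zeros(1, requires_grad=True)
opt = torch.optim.AdamW([param], lr=2e-4, betas=(0.8, 0.99), weight_decay=0.0)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=200_000, gamma=0.5, last_epoch=-1)
for step in range(400_001):
    opt.step()    # one training iteration (a no-op here: no gradients)
    sched.step()  # per step, i.e. scheduler_after_epoch=False
    if step in (0, 100_000, 200_000, 400_000):
        print(step, sched.get_last_lr())  # [2e-4], [2e-4], [1e-4], [5e-5]
```
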
optimizer: str = "AdamW" optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) - lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html - lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) - lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html - lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}) + lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}) + scheduler_after_epoch: bool = False From 127118c6378168e3d36a1e5d19ede777fd20684f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 11 Feb 2022 23:03:43 +0100 Subject: [PATCH 005/214] Update TTS.tts formatters (#1228) * Return Dict from tts formatters * Make style --- TTS/bin/compute_embeddings.py | 9 ++- TTS/bin/compute_statistics.py | 10 ++-- TTS/bin/find_unique_chars.py | 1 + TTS/bin/find_unique_phonemes.py | 5 ++ TTS/bin/train_tts.py | 1 + TTS/speaker_encoder/dataset.py | 10 ++-- TTS/speaker_encoder/models/resnet.py | 2 +- TTS/speaker_encoder/utils/generic_utils.py | 6 +- TTS/tts/datasets/__init__.py | 12 ++-- TTS/tts/datasets/dataset.py | 60 ++++++++----------- TTS/tts/datasets/formatters.py | 36 +++++------ TTS/tts/layers/generic/normalization.py | 2 +- TTS/tts/layers/generic/wavenet.py | 2 +- TTS/tts/layers/glow_tts/encoder.py | 2 +- TTS/tts/layers/glow_tts/transformer.py | 4 +- TTS/tts/layers/losses.py | 4 +- TTS/tts/layers/tacotron/gst_layers.py | 2 +- TTS/tts/layers/vits/networks.py | 2 +- .../vits/stochastic_duration_predictor.py | 6 +- TTS/tts/models/glow_tts.py | 8 +-- TTS/tts/models/vits.py | 22 ++++--- TTS/tts/utils/languages.py | 2 +- TTS/tts/utils/speakers.py | 4 +- TTS/tts/utils/ssim.py | 6 +- TTS/utils/audio.py | 24 ++++---- TTS/utils/download.py | 2 +- TTS/utils/training.py | 4 +- .../configs/parallel_wavegan_config.py | 4 +- TTS/vocoder/datasets/wavernn_dataset.py | 2 +- TTS/vocoder/layers/lvc_block.py | 4 +- TTS/vocoder/layers/melgan.py | 2 +- TTS/vocoder/layers/parallel_wavegan.py | 2 +- TTS/vocoder/models/hifigan_generator.py | 2 +- TTS/vocoder/models/melgan_generator.py | 2 +- .../models/parallel_wavegan_discriminator.py | 2 +- .../models/parallel_wavegan_generator.py | 2 +- TTS/vocoder/models/univnet_generator.py | 2 +- TTS/vocoder/models/wavegrad.py | 8 +-- TTS/vocoder/models/wavernn.py | 2 +- tests/data_tests/test_dataset_formatters.py | 10 ++-- tests/vocoder_tests/test_vocoder_wavernn.py | 2 +- 41 files changed, 153 insertions(+), 141 deletions(-) diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 2ac18651..50817154 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -29,7 +29,9 @@ parser.add_argument( help="Path to dataset config file.", ) parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.") -parser.add_argument("--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None) +parser.add_argument( + 
"--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None +) parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) parser.add_argument("--eval", type=bool, help="compute eval.", default=True) @@ -41,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli wav_files = meta_data_train + meta_data_eval speaker_manager = SpeakerManager( - encoder_model_path=args.model_path, encoder_config_path=args.config_path, d_vectors_file_path=args.old_file, use_cuda=args.use_cuda + encoder_model_path=args.model_path, + encoder_config_path=args.config_path, + d_vectors_file_path=args.old_file, + use_cuda=args.use_cuda, ) # compute speaker embeddings diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index e1974ae7..3ab7ea7a 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -51,7 +51,7 @@ def main(): N = 0 for item in tqdm(dataset_items): # compute features - wav = ap.load_wav(item if isinstance(item, str) else item[1]) + wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"]) linear = ap.spectrogram(wav) mel = ap.melspectrogram(wav) @@ -59,13 +59,13 @@ def main(): N += mel.shape[1] mel_sum += mel.sum(1) linear_sum += linear.sum(1) - mel_square_sum += (mel ** 2).sum(axis=1) - linear_square_sum += (linear ** 2).sum(axis=1) + mel_square_sum += (mel**2).sum(axis=1) + linear_square_sum += (linear**2).sum(axis=1) mel_mean = mel_sum / N - mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2) + mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2) linear_mean = linear_sum / N - linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2) + linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2) output_file_path = args.out_path stats = {} diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index 437c2d60..fb98bab5 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -24,6 +24,7 @@ def main(): # load all datasets train_items, eval_items = load_tts_samples(c.datasets, eval_split=True) + items = train_items + eval_items texts = "".join(item[0] for item in items) diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index d3143ca3..02a783c7 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -43,6 +43,11 @@ def main(): items = train_items + eval_items print("Num items:", len(items)) + is_lang_def = all(item["language"] for item in items) + + if not c.phoneme_language or not is_lang_def: + raise ValueError("Phoneme language must be defined in config.") + phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15) phones = [] for ph in phonemes: diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 0f8c4760..a7ce8ef3 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,4 +1,5 @@ import os + import torch from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py index 5b0fee22..28a23e2f 100644 --- a/TTS/speaker_encoder/dataset.py +++ b/TTS/speaker_encoder/dataset.py @@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset): mel = self.ap.melspectrogram(wav).astype("float32") # sample seq_len - assert text.size > 0, self.items[idx][1] - assert wav.size > 0, self.items[idx][1] + assert text.size > 0, self.items[idx]["audio_file"] + 
assert wav.size > 0, self.items[idx]["audio_file"] sample = { "mel": mel, - "item_idx": self.items[idx][1], + "item_idx": self.items[idx]["audio_file"], "speaker_name": speaker_name, } return sample @@ -91,8 +91,8 @@ class SpeakerEncoderDataset(Dataset): def __parse_items(self): self.speaker_to_utters = {} for i in self.items: - path_ = i[1] - speaker_ = i[2] + path_ = i["audio_file"] + speaker_ = i["speaker_name"] if speaker_ in self.speaker_to_utters.keys(): self.speaker_to_utters[speaker_].append(path_) else: diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index d6c3dad4..a799fc52 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -229,7 +229,7 @@ class ResNetSpeakerEncoder(nn.Module): x = torch.sum(x * w, dim=2) elif self.encoder_type == "ASP": mu = torch.sum(x * w, dim=2) - sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5)) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) x = torch.cat((mu, sg), 1) x = x.view(x.size()[0], -1) diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index b8aa4093..4ab4e923 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -113,7 +113,7 @@ class AugmentWAV(object): def additive_noise(self, noise_type, audio): - clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4) + clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4) noise_list = random.sample( self.noise_list[noise_type], @@ -135,7 +135,7 @@ class AugmentWAV(object): self.additive_noise_config[noise_type]["min_snr_in_db"], self.additive_noise_config[noise_type]["max_num_noises"], ) - noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4) + noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4) noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio if noises_wav is None: @@ -154,7 +154,7 @@ class AugmentWAV(object): rir_file = random.choice(self.rir_files) rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) - rir = rir / np.sqrt(np.sum(rir ** 2)) + rir = rir / np.sqrt(np.sum(rir**2)) return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] def apply_one(self, audio): diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 40eed7e3..455413fa 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -75,14 +75,14 @@ def load_tts_samples( formatter = _get_formatter_by_name(name) # load train set meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) - meta_data_train = [[*item, language] for item in meta_data_train] + meta_data_train = [{**item, **{"language": language}} for item in meta_data_train] print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") # load evaluation split if set if eval_split: if meta_file_val: meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) - meta_data_eval = [[*item, language] for item in meta_data_eval] + meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval] else: meta_data_eval, meta_data_train = split_dataset(meta_data_train) meta_data_eval_all += meta_data_eval @@ -91,12 +91,12 @@ def load_tts_samples( if dataset.meta_file_attn_mask: meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): - attn_file = meta_data[ins[1]].strip() - 
meta_data_train_all[idx].append(attn_file) + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_train_all[idx].update({"alignment_file": attn_file}) if meta_data_eval_all: for idx, ins in enumerate(meta_data_eval_all): - attn_file = meta_data[ins[1]].strip() - meta_data_eval_all[idx].append(attn_file) + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_eval_all[idx].update({"alignment_file": attn_file}) # set none for the next iter formatter = None return meta_data_train_all, meta_data_eval_all diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 2f20c865..546f012d 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -21,7 +21,7 @@ class TTSDataset(Dataset): text_cleaner: list, compute_linear_spec: bool, ap: AudioProcessor, - meta_data: List[List], + meta_data: List[Dict], compute_f0: bool = False, f0_cache_path: str = None, characters: Dict = None, @@ -54,7 +54,7 @@ class TTSDataset(Dataset): ap (TTS.tts.utils.AudioProcessor): Audio processor object. - meta_data (list): List of dataset instances. + meta_data (list): List of dataset samples. compute_f0 (bool): compute f0 if True. Defaults to False. @@ -199,15 +199,9 @@ class TTSDataset(Dataset): def load_data(self, idx): item = self.items[idx] + raw_text = item["text"] - if len(item) == 5: - text, wav_file, speaker_name, language_name, attn_file = item - else: - text, wav_file, speaker_name, language_name = item - attn = None - raw_text = text - - wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) + wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) # apply noise for augmentation if self.use_noise_augment: @@ -216,12 +210,12 @@ class TTSDataset(Dataset): if not self.input_seq_computed: if self.use_phonemes: text = self._load_or_generate_phoneme_sequence( - wav_file, - text, + item["audio_file"], + item["text"], self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, - language_name if language_name else self.phoneme_language, + item["language"] if item["language"] else self.phoneme_language, self.custom_symbols, self.characters, self.add_blank, @@ -229,7 +223,7 @@ class TTSDataset(Dataset): else: text = np.asarray( text_to_sequence( - text, + item["text"], [self.cleaners], custom_symbols=self.custom_symbols, tp=self.characters, @@ -238,11 +232,12 @@ class TTSDataset(Dataset): dtype=np.int32, ) - assert text.size > 0, self.items[idx][1] - assert wav.size > 0, self.items[idx][1] + assert text.size > 0, self.items[idx]["audio_file"] + assert wav.size > 0, self.items[idx]["audio_file"] - if "attn_file" in locals(): - attn = np.load(attn_file) + attn = None + if "alignment_file" in item: + attn = np.load(item["alignment_file"]) if len(text) > self.max_seq_len: # return a different sample if the phonemized @@ -252,7 +247,7 @@ class TTSDataset(Dataset): pitch = None if self.compute_f0: - pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) + pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path) pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32)) sample = { @@ -261,10 +256,10 @@ class TTSDataset(Dataset): "wav": wav, "pitch": pitch, "attn": attn, - "item_idx": self.items[idx][1], - "speaker_name": speaker_name, - "language_name": language_name, - "wav_file_name": os.path.basename(wav_file), + "item_idx": item["audio_file"], + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "wav_file_name": 
os.path.basename(item["audio_file"]), } return sample @@ -272,11 +267,10 @@ class TTSDataset(Dataset): def _phoneme_worker(args): item = args[0] func_args = args[1] - text, wav_file, *_ = item func_args[3] = ( - item[3] if item[3] else func_args[3] + item["language"] if "language" in item and item["language"] else func_args[3] ) # override phoneme language if specified by the dataset formatter - phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args) + phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args) return phonemes def compute_input_seq(self, num_workers=0): @@ -286,10 +280,9 @@ class TTSDataset(Dataset): if self.verbose: print(" | > Computing input sequences ...") for idx, item in enumerate(tqdm.tqdm(self.items)): - text, *_ = item sequence = np.asarray( text_to_sequence( - text, + item["text"], [self.cleaners], custom_symbols=self.custom_symbols, tp=self.characters, @@ -337,10 +330,10 @@ class TTSDataset(Dataset): if by_audio_len: lengths = [] for item in self.items: - lengths.append(os.path.getsize(item[1]) / 16 * 8) # assuming 16bit audio + lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8) # assuming 16bit audio lengths = np.array(lengths) else: - lengths = np.array([len(ins[0]) for ins in self.items]) + lengths = np.array([len(ins["text"]) for ins in self.items]) idxs = np.argsort(lengths) new_items = [] @@ -555,7 +548,7 @@ class PitchExtractor: def __init__( self, - items: List[List], + items: List[Dict], verbose=False, ): self.items = items @@ -614,10 +607,9 @@ class PitchExtractor: item = args[0] ap = args[1] cache_path = args[2] - _, wav_file, *_ = item - pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path) if not os.path.exists(pitch_file): - pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file) return pitch return None diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 1f23f85e..28eb0e0f 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("\t") wav_file = os.path.join(root_path, cols[0] + ".wav") text = cols[1] - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument wav_file = cols[1].strip() text = cols[0].strip() wav_file = os.path.join(root_path, "wavs", wav_file) - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume text = cols[1].strip() folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" wav_file = os.path.join(root_path, folder_name, wav_file) - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None): wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") if os.path.isfile(wav_file): text = cols[1].strip() - 
items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) else: # M-AI-Labs have some missing samples, so just print the warning print("> File %s does not exist!" % (wav_file)) @@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[2] - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -133,7 +133,7 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[2] - items.append([text, wav_file, f"ljspeech-{idx}"]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{idx}"}) return items @@ -150,7 +150,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg if not os.path.exists(wav_file): print(f" [!] {wav_file} in metafile does not exist. Skipping...") continue - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -165,7 +165,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("|") wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav") text = cols[1] - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -179,7 +179,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("|") wav_file = os.path.join(root_path, cols[0]) text = cols[1] - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -193,7 +193,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument utt_id = line.split()[1] text = line[line.find('"') + 1 : line.rfind('"') - 1] wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -213,7 +213,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None): if speaker_name in ignored_speakers: continue wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) - items.append([text, wav_file, "MCV_" + speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name}) return items @@ -240,7 +240,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_name in ignored_speakers: continue - items.append([text, wav_file, "LTTS_" + speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"}) for item in items: assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}" return items @@ -259,7 +259,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar skipped_files.append(wav_file) continue text = cols[1].strip() - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) print(f" [!] {len(skipped_files)} files skipped. 
They don't exist...") return items @@ -281,7 +281,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_id in ignored_speakers: continue - items.append([text, wav_file, speaker_id]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id}) return items @@ -299,7 +299,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): with open(meta_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") - items.append([text, wav_file, "VCTK_" + speaker_id]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id}) return items @@ -334,7 +334,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker in ignored_speakers: continue - items.append([text, wav_file, "MLS_" + speaker]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker}) return items @@ -404,7 +404,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin for line in ttf: wav_name, text = line.rstrip("\n").split("|") wav_path = os.path.join(root_path, "clips_22", wav_name) - items.append([text, wav_path, speaker_name]) + items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name}) return items @@ -418,5 +418,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[2].replace(" ", "") - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py index 4766c77d..c0270e40 100644 --- a/TTS/tts/layers/generic/normalization.py +++ b/TTS/tts/layers/generic/normalization.py @@ -113,7 +113,7 @@ class ActNorm(nn.Module): denom = torch.sum(x_mask, [0, 2]) m = torch.sum(x * x_mask, [0, 2]) / denom m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom - v = m_sq - (m ** 2) + v = m_sq - (m**2) logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py index 0c87e9df..aeb45c7b 100644 --- a/TTS/tts/layers/generic/wavenet.py +++ b/TTS/tts/layers/generic/wavenet.py @@ -65,7 +65,7 @@ class WN(torch.nn.Module): self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") # intermediate layers for i in range(num_layers): - dilation = dilation_rate ** i + dilation = dilation_rate**i padding = int((kernel_size * dilation - dilation) / 2) in_layer = torch.nn.Conv1d( hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py index 36ed668b..3b43e527 100644 --- a/TTS/tts/layers/glow_tts/encoder.py +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -101,7 +101,7 @@ class Encoder(nn.Module): self.encoder_type = encoder_type # embedding layer self.emb = nn.Embedding(num_chars, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) # init encoder module if encoder_type.lower() == "rel_pos_transformer": if use_prenet: diff --git 
a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index ba6aa1e2..0f837abf 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -88,7 +88,7 @@ class RelativePositionMultiHeadAttention(nn.Module): # relative positional encoding layers if rel_attn_window_size is not None: n_heads_rel = 1 if heads_share else num_heads - rel_stddev = self.k_channels ** -0.5 + rel_stddev = self.k_channels**-0.5 emb_rel_k = nn.Parameter( torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev ) @@ -235,7 +235,7 @@ class RelativePositionMultiHeadAttention(nn.Module): batch, heads, length, _ = x.size() # padd along column x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) - x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) # add 0's in the beginning that will skew the elements after reshape x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 7de45041..d770a536 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -218,7 +218,7 @@ class GuidedAttentionLoss(torch.nn.Module): def _make_ga_mask(ilen, olen, sigma): grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen)) grid_x, grid_y = grid_x.float(), grid_y.float() - return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2))) + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))) @staticmethod def _make_masks(ilens, olens): @@ -665,7 +665,7 @@ class VitsDiscriminatorLoss(nn.Module): dr = dr.float() dg = dg.float() real_loss = torch.mean((1 - dr) ** 2) - fake_loss = torch.mean(dg ** 2) + fake_loss = torch.mean(dg**2) loss += real_loss + fake_loss real_losses.append(real_loss.item()) fake_losses.append(fake_loss.item()) diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py index 01a81e0b..7d751bc0 100644 --- a/TTS/tts/layers/tacotron/gst_layers.py +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -141,7 +141,7 @@ class MultiHeadAttention(nn.Module): # score = softmax(QK^T / (d_k ** 0.5)) scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k] - scores = scores / (self.key_dim ** 0.5) + scores = scores / (self.key_dim**0.5) scores = F.softmax(scores, dim=3) # out = score * V diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index ef426ace..7c225344 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -57,7 +57,7 @@ class TextEncoder(nn.Module): self.emb = nn.Embedding(n_vocab, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) if language_emb_dim: hidden_channels += language_emb_dim diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py index 120d0944..738ee341 100644 --- a/TTS/tts/layers/vits/stochastic_duration_predictor.py +++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py @@ -33,7 +33,7 @@ class DilatedDepthSeparableConv(nn.Module): self.norms_1 = nn.ModuleList() self.norms_2 = nn.ModuleList() for i in range(num_layers): - dilation = kernel_size ** i + dilation = kernel_size**i padding = (kernel_size * dilation - dilation) // 2 self.convs_sep.append( 
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) @@ -264,7 +264,7 @@ class StochasticDurationPredictor(nn.Module): # posterior encoder - neg log likelihood logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) nll_posterior_encoder = ( - torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q + torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q ) z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask @@ -279,7 +279,7 @@ class StochasticDurationPredictor(nn.Module): z = torch.flip(z, [1]) # flow layers - neg log likelihood - nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot + nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot return nll_flow_layers + nll_posterior_encoder flows = list(reversed(self.flows)) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index c1e4c2ac..7dbfdd09 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -206,9 +206,9 @@ class GlowTTS(BaseTTS): with torch.no_grad(): o_scale = torch.exp(-2 * o_log_scale) logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] - logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) @@ -255,9 +255,9 @@ class GlowTTS(BaseTTS): # find the alignment path between z and encoder output o_scale = torch.exp(-2 * o_log_scale) logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] - logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index cb349ca2..ae24a99e 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -4,7 +4,6 @@ from itertools import chain from typing import Dict, List, Tuple import torch - import torchaudio from coqpit import Coqpit from torch import nn @@ -424,9 +423,9 @@ class Vits(BaseTTS): and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"] ): self.audio_transform = torchaudio.transforms.Resample( - orig_freq=self.audio_config["sample_rate"], - new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], - ) + orig_freq=self.audio_config["sample_rate"], + 
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], + ) else: self.audio_transform = None @@ -591,9 +590,9 @@ class Vits(BaseTTS): with torch.no_grad(): o_scale = torch.exp(-2 * logs_p) logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)]) + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) - logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp2 + logp3 + logp1 + logp4 attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() @@ -692,10 +691,17 @@ class Vits(BaseTTS): if self.args.use_sdp: logw = self.duration_predictor( - x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb + x, + x_mask, + g=g if self.args.condition_dp_on_speaker else None, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb, ) else: - logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb) + logw = self.duration_predictor( + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + ) w = torch.exp(logw) * x_mask * self.length_scale w_ceil = torch.ceil(w) diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index fc7eec57..a4f41be5 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -113,7 +113,7 @@ def _set_file_path(path): def get_language_weighted_sampler(items: list): - language_names = np.array([item[3] for item in items]) + language_names = np.array([item["language"] for item in items]) unique_language_names = np.unique(language_names).tolist() language_ids = [unique_language_names.index(l) for l in language_names] language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names]) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 07076d90..441296ac 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -118,7 +118,7 @@ class SpeakerManager: Returns: Tuple[Dict, int]: speaker IDs and number of speakers. 
""" - speakers = sorted({item[2] for item in items}) + speakers = sorted({item["speaker_name"] for item in items}) speaker_ids = {name: i for i, name in enumerate(speakers)} num_speakers = len(speaker_ids) return speaker_ids, num_speakers @@ -414,7 +414,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, def get_speaker_weighted_sampler(items: list): - speaker_names = np.array([item[2] for item in items]) + speaker_names = np.array([item["speaker_name"] for item in items]) unique_speaker_names = np.unique(speaker_names).tolist() speaker_ids = [unique_speaker_names.index(l) for l in speaker_names] speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names]) diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index 883efdb8..ab2c6991 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -8,7 +8,7 @@ from torch.autograd import Variable def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)]) + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) return gauss / gauss.sum() @@ -33,8 +33,8 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 - C1 = 0.01 ** 2 - C2 = 0.03 ** 2 + C1 = 0.01**2 + C2 = 0.03**2 ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 25f93c34..0253f918 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -142,10 +142,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method ) M = o[:, :, :, 0] P = o[:, :, :, 1] - S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) + S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) if self.power is not None: - S = S ** self.power + S = S**self.power if self.use_mel: S = torch.matmul(self.mel_basis.to(x), S) @@ -634,8 +634,8 @@ class AudioProcessor(object): S = self._db_to_amp(S) # Reconstruct phase if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) - return self._griffin_lim(S ** self.power) + return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) + return self._griffin_lim(S**self.power) def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" @@ -643,8 +643,8 @@ class AudioProcessor(object): S = self._db_to_amp(D) S = self._mel_to_linear(S) # Convert back to linear if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) - return self._griffin_lim(S ** self.power) + return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) + return self._griffin_lim(S**self.power) def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: """Convert a full scale linear spectrogram output of a network to a melspectrogram. 
@@ -781,7 +781,7 @@ class AudioProcessor(object): @staticmethod def _rms_norm(wav, db_level=-27): r = 10 ** (db_level / 20) - a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2)) + a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) return wav * a def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: @@ -853,7 +853,7 @@ class AudioProcessor(object): @staticmethod def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray: - mu = 2 ** qc - 1 + mu = 2**qc - 1 # wav_abs = np.minimum(np.abs(wav), 1.0) signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) # Quantize signal to the specified number of levels. @@ -865,13 +865,13 @@ class AudioProcessor(object): @staticmethod def mulaw_decode(wav, qc): """Recovers waveform from quantized values.""" - mu = 2 ** qc - 1 + mu = 2**qc - 1 x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) return x @staticmethod def encode_16bits(x): - return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) @staticmethod def quantize(x: np.ndarray, bits: int) -> np.ndarray: @@ -884,12 +884,12 @@ class AudioProcessor(object): Returns: np.ndarray: Quantized waveform. """ - return (x + 1.0) * (2 ** bits - 1) / 2 + return (x + 1.0) * (2**bits - 1) / 2 @staticmethod def dequantize(x, bits): """Dequantize a waveform from the given number of bits.""" - return 2 * x / (2 ** bits - 1) - 1 + return 2 * x / (2**bits - 1) - 1 def _log(x, base): diff --git a/TTS/utils/download.py b/TTS/utils/download.py index 241a106b..de9b31a7 100644 --- a/TTS/utils/download.py +++ b/TTS/utils/download.py @@ -128,7 +128,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> while True: # Read by chunk to avoid filling memory - chunk = file_obj.read(1024 ** 2) + chunk = file_obj.read(1024**2) if not chunk: break hash_func.update(chunk) diff --git a/TTS/utils/training.py b/TTS/utils/training.py index aa5651c5..9f01b310 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -39,7 +39,7 @@ class NoamLR(torch.optim.lr_scheduler._LRScheduler): def get_lr(self): step = max(self.last_epoch, 1) return [ - base_lr * self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5) + base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5) for base_lr in self.base_lrs ] @@ -63,7 +63,7 @@ def lr_decay(init_lr, global_step, warmup_steps): It is only being used by the Speaker Encoder trainer.""" warmup_steps = float(warmup_steps) step = global_step + 1.0 - lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5) + lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, step**-0.5) return lr diff --git a/TTS/vocoder/configs/parallel_wavegan_config.py b/TTS/vocoder/configs/parallel_wavegan_config.py index f536ba98..7845dd6b 100644 --- a/TTS/vocoder/configs/parallel_wavegan_config.py +++ b/TTS/vocoder/configs/parallel_wavegan_config.py @@ -127,5 +127,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig): lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}) lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html - lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}) + 
lr_scheduler_disc_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1} + ) scheduler_after_epoch: bool = False diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index d648b68c..2c771cf0 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -111,7 +111,7 @@ class WaveRNNDataset(Dataset): elif isinstance(self.mode, int): coarse = np.stack(coarse).astype(np.int64) coarse = torch.LongTensor(coarse) - x_input = 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0 + x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0 y_coarse = coarse[:, 1:] mels = torch.FloatTensor(mels) return x_input, mels, y_coarse diff --git a/TTS/vocoder/layers/lvc_block.py b/TTS/vocoder/layers/lvc_block.py index 0e29ee3c..8913a113 100644 --- a/TTS/vocoder/layers/lvc_block.py +++ b/TTS/vocoder/layers/lvc_block.py @@ -126,9 +126,9 @@ class LVCBlock(torch.nn.Module): ) for i in range(conv_layers): - padding = (3 ** i) * int((conv_kernel_size - 1) / 2) + padding = (3**i) * int((conv_kernel_size - 1) / 2) conv = torch.nn.Conv1d( - in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i + in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i ) self.convs.append(conv) diff --git a/TTS/vocoder/layers/melgan.py b/TTS/vocoder/layers/melgan.py index 7fd999d9..4bb328e9 100644 --- a/TTS/vocoder/layers/melgan.py +++ b/TTS/vocoder/layers/melgan.py @@ -12,7 +12,7 @@ class ResidualStack(nn.Module): self.blocks = nn.ModuleList() for idx in range(num_res_blocks): layer_kernel_size = kernel_size - layer_dilation = layer_kernel_size ** idx + layer_dilation = layer_kernel_size**idx layer_padding = base_padding * layer_dilation self.blocks += [ nn.Sequential( diff --git a/TTS/vocoder/layers/parallel_wavegan.py b/TTS/vocoder/layers/parallel_wavegan.py index 889e8aa6..51142e5e 100644 --- a/TTS/vocoder/layers/parallel_wavegan.py +++ b/TTS/vocoder/layers/parallel_wavegan.py @@ -72,6 +72,6 @@ class ResidualBlock(torch.nn.Module): s = self.conv1x1_skip(x) # for residual connection - x = (self.conv1x1_out(x) + residual) * (0.5 ** 2) + x = (self.conv1x1_out(x) + residual) * (0.5**2) return x, s diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index 4ce743b3..fc15f3af 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -207,7 +207,7 @@ class HifiganGenerator(torch.nn.Module): self.ups.append( weight_norm( ConvTranspose1d( - upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py index e60baa9d..80b47870 100644 --- a/TTS/vocoder/models/melgan_generator.py +++ b/TTS/vocoder/models/melgan_generator.py @@ -36,7 +36,7 @@ class MelganGenerator(nn.Module): # upsampling layers and residual stacks for idx, upsample_factor in enumerate(upsample_factors): - layer_in_channels = base_channels // (2 ** idx) + layer_in_channels = base_channels // (2**idx) layer_out_channels = base_channels // (2 ** (idx + 1)) layer_filter_size = upsample_factor * 2 layer_stride = upsample_factor diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py index 9cc1061c..adf1bdae 100644 --- 
a/TTS/vocoder/models/parallel_wavegan_discriminator.py +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -35,7 +35,7 @@ class ParallelWaveganDiscriminator(nn.Module): if i == 0: dilation = 1 else: - dilation = i if dilation_factor == 1 else dilation_factor ** i + dilation = i if dilation_factor == 1 else dilation_factor**i conv_in_channels = conv_channels padding = (kernel_size - 1) // 2 * dilation conv_layer = [ diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index b8e78d03..ee9d8ad5 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -142,7 +142,7 @@ class ParallelWaveganGenerator(torch.nn.Module): self.apply(_apply_weight_norm) @staticmethod - def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x): + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): assert layers % stacks == 0 layers_per_cycle = layers // stacks dilations = [dilation(i % layers_per_cycle) for i in range(layers)] diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py index 8a66c537..2ee28c7b 100644 --- a/TTS/vocoder/models/univnet_generator.py +++ b/TTS/vocoder/models/univnet_generator.py @@ -130,7 +130,7 @@ class UnivnetGenerator(torch.nn.Module): self.apply(_apply_weight_norm) @staticmethod - def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x): + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): assert layers % stacks == 0 layers_per_cycle = layers // stacks dilations = [dilation(i % layers_per_cycle) for i in range(layers)] diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index ed4f4b37..00142c91 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -153,7 +153,7 @@ class Wavegrad(BaseVocoder): noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a) noise_scale = noise_scale.unsqueeze(1) noise = torch.randn_like(y_0) - noisy_audio = noise_scale * y_0 + (1.0 - noise_scale ** 2) ** 0.5 * noise + noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] def compute_noise_level(self, beta): @@ -161,8 +161,8 @@ class Wavegrad(BaseVocoder): self.num_steps = len(beta) alpha = 1 - beta alpha_hat = np.cumprod(alpha) - noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0) - noise_level = alpha_hat ** 0.5 + noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0) + noise_level = alpha_hat**0.5 # pylint: disable=not-callable self.beta = torch.tensor(beta.astype(np.float32)) @@ -170,7 +170,7 @@ class Wavegrad(BaseVocoder): self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32)) self.noise_level = torch.tensor(noise_level.astype(np.float32)) - self.c1 = 1 / self.alpha ** 0.5 + self.c1 = 1 / self.alpha**0.5 self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5 self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5 diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 1977efb6..b5b2343a 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -225,7 +225,7 @@ class Wavernn(BaseVocoder): super().__init__(config) if isinstance(self.args.mode, int): - self.n_classes = 2 ** self.args.mode + self.n_classes = 2**self.args.mode elif self.args.mode == "mold": 
            self.n_classes = 3 * 10
        elif self.args.mode == "gauss":
diff --git a/tests/data_tests/test_dataset_formatters.py b/tests/data_tests/test_dataset_formatters.py
index bd83002c..30fb79a8 100644
--- a/tests/data_tests/test_dataset_formatters.py
+++ b/tests/data_tests/test_dataset_formatters.py
@@ -5,13 +5,13 @@ from tests import get_tests_input_path
 from TTS.tts.datasets.formatters import common_voice


-class TestPreprocessors(unittest.TestCase):
+class TestTTSFormatters(unittest.TestCase):
     def test_common_voice_preprocessor(self):  # pylint: disable=no-self-use
         root_path = get_tests_input_path()
         meta_file = "common_voice.tsv"
         items = common_voice(root_path, meta_file)
-        assert items[0][0] == "The applicants are invited for coffee and visa is given immediately."
-        assert items[0][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
+        assert items[0]["text"] == "The applicants are invited for coffee and visa is given immediately."
+        assert items[0]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")

-        assert items[-1][0] == "Competition for limited resources has also resulted in some local conflicts."
-        assert items[-1][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")
+        assert items[-1]["text"] == "Competition for limited resources has also resulted in some local conflicts."
+        assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")
diff --git a/tests/vocoder_tests/test_vocoder_wavernn.py b/tests/vocoder_tests/test_vocoder_wavernn.py
index d4a7b8dd..966ea3dd 100644
--- a/tests/vocoder_tests/test_vocoder_wavernn.py
+++ b/tests/vocoder_tests/test_vocoder_wavernn.py
@@ -46,6 +46,6 @@ def test_wavernn():
     config.model_args.mode = 4
     model = Wavernn(config)
     output = model(dummy_x, dummy_m)
-    assert np.all(output.shape == (2, 1280, 2 ** 4)), output.shape
+    assert np.all(output.shape == (2, 1280, 2**4)), output.shape
     output = model.inference(dummy_y, True, 5500, 550)
     assert np.all(output.shape == (256 * (y_size - 1),))

From 2db67e3356a9514d8bb1f2372ac9121f2bf1c50c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 14 Feb 2022 10:49:25 +0000
Subject: [PATCH 006/214] Update dataset formatting docs

---
 docs/source/formatting_your_dataset.md | 57 +++++++++++++++++++++++---
 docs/source/index.md                   |  1 -
 2 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/docs/source/formatting_your_dataset.md b/docs/source/formatting_your_dataset.md
index 3db38af0..5b1d9801 100644
--- a/docs/source/formatting_your_dataset.md
+++ b/docs/source/formatting_your_dataset.md
@@ -58,23 +58,68 @@ If you use a different dataset format than the LJSpeech or the other public data
 If your dataset is in a new language or it needs special normalization steps, then you need a new `text_cleaner`.

-What you get out of a `formatter` is a `List[List[]]` in the following format.
+What you get out of a `formatter` is a `List[Dict]` in the following format.

 ```
 >>> formatter(metafile_path)
 [
     {"audio_file": "audio1.wav", "text": "This is my sentence.", "speaker_name": "MyDataset", "language": "lang_code"},
     {"audio_file": "audio1.wav", "text": "This is maybe a sentence.", "speaker_name": "MyDataset", "language": "lang_code"},
     ...
 ]
 ```

 Each item is parsed as ```{"text": "<text>", "audio_file": "<audio_path>", "speaker_name": "<speaker_name>"}```.
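To make the new sample format concrete, here is a minimal sketch of consuming it through one of the formatters changed above; the `ljspeech` formatter call and the `tests/data/ljspeech` path are borrowed from this patch series, and both paths are placeholders that stand in for any dataset root.

```python
import os

from TTS.tts.datasets.formatters import ljspeech

# Placeholder dataset root; any LJSpeech-layout dataset works here.
samples = ljspeech("tests/data/ljspeech", "metadata.csv")
for sample in samples[:3]:
    # After this change, access is by key instead of by list position.
    assert {"text", "audio_file", "speaker_name"} <= sample.keys()
    print(sample["speaker_name"], os.path.basename(sample["audio_file"]), sample["text"].strip())
```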
```<speaker_name>``` is the dataset name for single speaker datasets and it is mainly used in the multi-speaker models to map the speaker of each sample. But for now, we only focus on single speaker datasets.

-The purpose of a `formatter` is to parse your metafile and load the audio file paths and transcriptions. Then, its output passes to a `Dataset` object. It computes features from the audio signals, calls text normalization routines, and converts raw text to
+The purpose of a `formatter` is to parse your manifest file and load the audio file paths and transcriptions.
+Then, the output is passed to the `Dataset`. It computes features from the audio signals, calls text normalization routines, and converts raw text to
 phonemes if needed.

+## Loading your dataset
+
+Load one of the datasets supported by 🐸TTS.
+
+```python
+from TTS.tts.configs.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+
+
+# dataset config for one of the pre-defined datasets
+dataset_config = BaseDatasetConfig(
+    name="vctk", meta_file_train="", language="en-us", path="dataset-path"
+)
+
+# load training samples
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+```
+
+Load a custom dataset with a custom formatter.
+
+```python
+import os
+
+from TTS.tts.datasets import load_tts_samples
+
+
+# custom formatter implementation
+def formatter(root_path, manifest_file, **kwargs):  # pylint: disable=unused-argument
+    """Assumes each line as ```<filename>|<transcription>```
+    """
+    txt_file = os.path.join(root_path, manifest_file)
+    items = []
+    speaker_name = "my_speaker"
+    with open(txt_file, "r", encoding="utf-8") as ttf:
+        for line in ttf:
+            cols = line.split("|")
+            wav_file = os.path.join(root_path, "wavs", cols[0])
+            text = cols[1]
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+    return items
+
+# load training samples, reusing the `dataset_config` defined in the snippet above
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True, formatter=formatter)
+```
+
 See `TTS.tts.datasets.TTSDataset`, a generic `Dataset` implementation for the `tts` models.

 See `TTS.vocoder.datasets.*`, for different `Dataset` implementations for the `vocoder` models.
diff --git a/docs/source/index.md b/docs/source/index.md
index 756cea8e..9dc5bfce 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -27,7 +27,6 @@
     formatting_your_dataset
     what_makes_a_good_dataset
     tts_datasets
-    converting_torch_to_tf

 ..
toctree:: :maxdepth: 2 From 06cad27e31445331c3c27c32fde69c9819249153 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 18 Feb 2022 18:20:47 +0000 Subject: [PATCH 007/214] Add Glow-TTS multi-speaker unit test --- .../test_glow_tts_speaker_emb_train.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/tts_tests/test_glow_tts_speaker_emb_train.py diff --git a/tests/tts_tests/test_glow_tts_speaker_emb_train.py b/tests/tts_tests/test_glow_tts_speaker_emb_train.py new file mode 100644 index 00000000..9a1a1910 --- /dev/null +++ b/tests/tts_tests/test_glow_tts_speaker_emb_train.py @@ -0,0 +1,57 @@ +import glob +import os +import shutil + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.glow_tts_config import GlowTTSConfig +from TTS.utils.trainer_utils import get_last_checkpoint + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = GlowTTSConfig( + batch_size=2, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + "Be a voice, not an echo.", + ], + data_dep_init_steps=1.0, + use_speaker_embedding=True, +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) From ba6e56e01c8fd1a42177fc5717e29fa0d6990e05 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 18 Feb 2022 19:25:29 +0000 Subject: [PATCH 008/214] Fix Glow-TTS multi-speaker inference --- TTS/tts/models/glow_tts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 7dbfdd09..8f3b3804 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -170,6 +170,8 @@ class GlowTTS(BaseTTS): if g is not None: if hasattr(self, "emb_g"): # use speaker embedding layer + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] else: # use d-vector From 759f9ac76a22af865391daa0d7c46d0670c422ee Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 18 Feb 2022 20:03:36 +0000 Subject: [PATCH 009/214] Add Glow-TTS d-vectors training unit test --- .../test_glow_tts_d-vectors_train.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 
tests/tts_tests/test_glow_tts_d-vectors_train.py diff --git a/tests/tts_tests/test_glow_tts_d-vectors_train.py b/tests/tts_tests/test_glow_tts_d-vectors_train.py new file mode 100644 index 00000000..5b82eebb --- /dev/null +++ b/tests/tts_tests/test_glow_tts_d-vectors_train.py @@ -0,0 +1,60 @@ +import glob +import os +import shutil + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.glow_tts_config import GlowTTSConfig +from TTS.utils.trainer_utils import get_last_checkpoint + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = GlowTTSConfig( + batch_size=2, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + "Be a voice, not an echo.", + ], + data_dep_init_steps=1.0, + use_speaker_embedding=False, + use_d_vector_file=True, + d_vector_file="tests/data/ljspeech/speakers.json", + d_vector_dim=256, +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) From 5cca4aa8aebe689cebd1dbda70ad648b42ee5407 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 18 Feb 2022 20:16:52 +0000 Subject: [PATCH 010/214] Add FastPitch Speaker embedding train unit test --- .../test_fast_pitch_speaker_emb_train.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tests/tts_tests/test_fast_pitch_speaker_emb_train.py diff --git a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py new file mode 100644 index 00000000..c526e33a --- /dev/null +++ b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py @@ -0,0 +1,69 @@ +import glob +import os +import shutil + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.config.shared_configs import BaseAudioConfig +from TTS.tts.configs.fast_pitch_config import FastPitchConfig + +config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = FastPitchConfig( + audio=audio_config, + 
batch_size=8, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + f0_cache_path="tests/data/ljspeech/f0_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + use_speaker_embedding=True, + test_sentences=[ + "Be a voice, not an echo.", + ], +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) From fc7081fc5e05fff2b4d52856588b947694bd199f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 18 Feb 2022 21:06:08 +0000 Subject: [PATCH 011/214] Add Inference test using TTS API in all models unit tests --- tests/data/ljspeech/f0_cache/pitch_stats.npy | Bin 0 -> 424 bytes tests/tts_tests/test_align_tts_train.py | 9 +++++++++ .../test_fast_pitch_speaker_emb_train.py | 11 ++++++++++ tests/tts_tests/test_fast_pitch_train.py | 9 +++++++++ .../test_glow_tts_d-vectors_train.py | 10 +++++++++ .../test_glow_tts_speaker_emb_train.py | 10 +++++++++ tests/tts_tests/test_glow_tts_train.py | 9 +++++++++ tests/tts_tests/test_speedy_speech_train.py | 9 +++++++++ .../test_tacotron2_d-vectors_train.py | 11 ++++++++++ .../test_tacotron2_speaker_emb_train.py | 11 ++++++++++ tests/tts_tests/test_tacotron2_train.py | 9 +++++++++ tests/tts_tests/test_tacotron_train.py | 9 +++++++++ ...st_vits_multilingual_speaker_emb_train.py} | 15 +++++++++++++- .../test_vits_multilingual_train-d_vectors.py | 19 +++++++++++++++--- .../tts_tests/test_vits_speaker_emb_train.py | 15 ++++++++++++-- tests/tts_tests/test_vits_train.py | 9 +++++++++ 16 files changed, 159 insertions(+), 6 deletions(-) create mode 100644 tests/data/ljspeech/f0_cache/pitch_stats.npy rename tests/tts_tests/{test_vits_multilingual_train.py => test_vits_multilingual_speaker_emb_train.py} (74%) diff --git a/tests/data/ljspeech/f0_cache/pitch_stats.npy b/tests/data/ljspeech/f0_cache/pitch_stats.npy new file mode 100644 index 0000000000000000000000000000000000000000..aaa385c3c07d9eb8739ab504b8bdb7e34f0002d5 GIT binary patch literal 424 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1Jt-^H=`z0QA$Z=K`K`vTLcpW1B1UsA$w;> zdm%?qA*Y5na|9z$tfr95&(F{6KM;TkZ~Kx$?xfDxLY~?}UX2JAppx9w#Joa29BwO4 zPRvOx;wt3NfY^~{Q78biLoldN2xf;(p)jf)3?+pkNzNQkzxYgW;mLI<6m2UM3n~ Date: Sat, 19 Feb 2022 12:15:03 +0000 Subject: [PATCH 012/214] Fix unit tests issue --- tests/tts_tests/test_fast_pitch_speaker_emb_train.py | 4 +++- tests/tts_tests/test_fast_pitch_train.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git 
a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py index 1b777803..59e90e0a 100644 --- a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py +++ b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py @@ -7,7 +7,7 @@ from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.utils.trainer_utils import get_last_checkpoint -config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json") +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") audio_config = BaseAudioConfig( @@ -45,6 +45,8 @@ config = FastPitchConfig( ], ) config.audio.do_trim_silence = True +config.use_speaker_embedding = True +config.model_args.use_speaker_embedding = True config.audio.trim_db = 60 config.save_json(config_path) diff --git a/tests/tts_tests/test_fast_pitch_train.py b/tests/tts_tests/test_fast_pitch_train.py index 9aae5bbd..bbfbb823 100644 --- a/tests/tts_tests/test_fast_pitch_train.py +++ b/tests/tts_tests/test_fast_pitch_train.py @@ -7,7 +7,7 @@ from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.utils.trainer_utils import get_last_checkpoint -config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json") +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") audio_config = BaseAudioConfig( @@ -42,8 +42,11 @@ config = FastPitchConfig( test_sentences=[ "Be a voice, not an echo.", ], + use_speaker_embedding=False, ) config.audio.do_trim_silence = True +config.use_speaker_embedding = False +config.model_args.use_speaker_embedding = False config.audio.trim_db = 60 config.save_json(config_path) @@ -58,6 +61,7 @@ command_train = ( "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " "--coqpit.test_delay_epochs 0" ) + run_cli(command_train) # Find latest folder From 531821545e40e2a8fba7d351be52b58bfd61458f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 19 Feb 2022 12:21:32 +0000 Subject: [PATCH 013/214] Fix inference test issue --- tests/inference_tests/test_synthesizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index 5972dc90..97878574 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -6,7 +6,7 @@ from TTS.tts.models import setup_model from TTS.utils.io import save_checkpoint from TTS.utils.synthesizer import Synthesizer -from .. 
import get_tests_output_path +from tests import get_tests_output_path class SynthesizerTest(unittest.TestCase): From 05fffb0ebc549a75f77b37fb71f7f42002f4d1c2 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 19 Feb 2022 14:42:24 +0000 Subject: [PATCH 014/214] Add inference unit test on GitHub workflow --- .github/workflows/inference_tests.yml | 46 +++++++++++++++++++++++++++ Makefile | 3 ++ 2 files changed, 49 insertions(+) create mode 100644 .github/workflows/inference_tests.yml diff --git a/.github/workflows/inference_tests.yml b/.github/workflows/inference_tests.yml new file mode 100644 index 00000000..3f08b904 --- /dev/null +++ b/.github/workflows/inference_tests.yml @@ -0,0 +1,46 @@ +name: inference_tests + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + check_skip: + runs-on: ubuntu-latest + if: "! contains(github.event.head_commit.message, '[ci skip]')" + steps: + - run: echo "${{ github.event.head_commit.message }}" + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + experimental: [false] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: coqui-ai/setup-python@pip-cache-key-py-ver + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel + - name: Install TTS + run: | + python3 -m pip install .[all] + python3 setup.py egg_info + - name: Unit tests + run: make inference_tests diff --git a/Makefile b/Makefile index 2632dbab..bb849981 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,9 @@ test_aux: ## run aux tests. test_zoo: ## run zoo tests. nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id +inference_tests: ## run inference tests. + nosetests tests.inference_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.inference_tests --nologcapture --with-id + test_failed: ## only run tests failed the last time. 
nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed From bc5db13d067332baecb932f54fe4abe5398be016 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 19 Feb 2022 19:24:00 +0000 Subject: [PATCH 015/214] Fix the bug in extract tts spectrogram script --- TTS/bin/extract_tts_spectrograms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 7b489fd6..386cf332 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -138,7 +138,7 @@ def inference( aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}, ) model_output = outputs["model_outputs"] - model_output = model_output.transpose(1, 2).detach().cpu().numpy() + model_output = model_output.detach().cpu().numpy() elif "tacotron" in model_name: aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} From 28a746497560dc3f1f3415827ef38d7d9d72dbbf Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 21 Feb 2022 05:59:36 -0300 Subject: [PATCH 016/214] Fix the bug in split dataset function (#1251) * Fix the bug in split_dataset * Make eval_split_size configurable * Change test_loader to use load_tts_samples function * Change eval_split_portion to eval_split_size and permits to set the absolute number of samples in eval * Fix samplers unit test * Add data unit test on GitHub workflow --- .github/workflows/data_tests.yml | 46 +++++++++++++++++++++++++++++ Makefile | 3 ++ TTS/bin/extract_tts_spectrograms.py | 2 +- TTS/bin/find_unique_chars.py | 2 +- TTS/bin/find_unique_phonemes.py | 2 +- TTS/bin/train_tts.py | 2 +- TTS/tts/configs/shared_configs.py | 10 +++++++ TTS/tts/datasets/__init__.py | 41 +++++++++++++++++++------ TTS/tts/datasets/formatters.py | 6 +++- tests/data_tests/test_loader.py | 27 ++++++++++++----- tests/data_tests/test_samplers.py | 4 +-- 11 files changed, 121 insertions(+), 24 deletions(-) create mode 100644 .github/workflows/data_tests.yml diff --git a/.github/workflows/data_tests.yml b/.github/workflows/data_tests.yml new file mode 100644 index 00000000..296aa570 --- /dev/null +++ b/.github/workflows/data_tests.yml @@ -0,0 +1,46 @@ +name: data-tests + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + check_skip: + runs-on: ubuntu-latest + if: "! contains(github.event.head_commit.message, '[ci skip]')" + steps: + - run: echo "${{ github.event.head_commit.message }}" + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + experimental: [false] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: coqui-ai/setup-python@pip-cache-key-py-ver + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel + - name: Install TTS + run: | + python3 -m pip install .[all] + python3 setup.py egg_info + - name: Unit tests + run: make data_tests diff --git a/Makefile b/Makefile index 2632dbab..6752fa04 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,9 @@ test_aux: ## run aux tests. test_zoo: ## run zoo tests. 
nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id +data_tests: ## run data tests. + nosetests tests.data_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.data_tests --nologcapture --with-id + test_failed: ## only run tests failed the last time. nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 7b489fd6..e21f57c9 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -229,7 +229,7 @@ def main(args): # pylint: disable=redefined-outer-name ap = AudioProcessor(**c.audio) # load data instances - meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval) + meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) # use eval and training partitions meta_data = meta_data_train + meta_data_eval diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index fb98bab5..541e971b 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -23,7 +23,7 @@ def main(): c = load_config(args.config_path) # load all datasets - train_items, eval_items = load_tts_samples(c.datasets, eval_split=True) + train_items, eval_items = load_tts_samples(c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) items = train_items + eval_items diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index 02a783c7..ad567434 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -39,7 +39,7 @@ def main(): c = load_config(args.config_path) # load all datasets - train_items, eval_items = load_tts_samples(c.datasets, eval_split=True) + train_items, eval_items = load_tts_samples(c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) items = train_items + eval_items print("Num items:", len(items)) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index a7ce8ef3..16251fdd 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -42,7 +42,7 @@ def main(): config = register_config(config_base.model)() # load training samples - train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True) + train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True, eval_split_max_size=config.eval_split_max_size, eval_split_size=config.eval_split_size) # setup audio processor ap = AudioProcessor(**config.audio) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 60ef7276..65ed21de 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -183,6 +183,13 @@ class BaseTTSConfig(BaseTrainingConfig): test_sentences (List[str]): List of sentences to be used at testing. Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). 
""" audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) @@ -218,3 +225,6 @@ class BaseTTSConfig(BaseTrainingConfig): lr_scheduler_params: dict = field(default_factory=lambda: {}) # testing test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 455413fa..d80e92c9 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -9,25 +9,40 @@ from TTS.tts.datasets.dataset import * from TTS.tts.datasets.formatters import * -def split_dataset(items): +def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. Args: - items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + items (List[List]): + A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). """ - speakers = [item[-1] for item in items] + speakers = [item["speaker_name"] for item in items] is_multi_speaker = len(set(speakers)) > 1 - eval_split_size = min(500, int(len(items) * 0.01)) - assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples." + if eval_split_size > 1: + eval_split_size = int(eval_split_size) + else: + if eval_split_max_size: + eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size)) + else: + eval_split_size = int(len(items) * eval_split_size) + + assert eval_split_size > 0, " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format(1/len(items)) np.random.seed(0) np.random.shuffle(items) if is_multi_speaker: items_eval = [] - speakers = [item[-1] for item in items] + speakers = [item["speaker_name"] for item in items] speaker_counter = Counter(speakers) while len(items_eval) < eval_split_size: item_idx = np.random.randint(0, len(items)) - speaker_to_be_removed = items[item_idx][-1] + speaker_to_be_removed = items[item_idx]["speaker_name"] if speaker_counter[speaker_to_be_removed] > 1: items_eval.append(items[item_idx]) speaker_counter[speaker_to_be_removed] -= 1 @@ -37,7 +52,8 @@ def split_dataset(items): def load_tts_samples( - datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None + datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None, + eval_split_max_size=None, eval_split_size=0.01 ) -> Tuple[List[List], List[List]]: """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based @@ -55,6 +71,13 @@ def load_tts_samples( `[[audio_path, text, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as example. Defaults to None. + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). 
+ + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + Returns: Tuple[List[List], List[List]: training and evaluation splits of the dataset. """ @@ -84,7 +107,7 @@ def load_tts_samples( meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval] else: - meta_data_eval, meta_data_train = split_dataset(meta_data_train) + meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size) meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train # load attention masks for the duration predictor training diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 28eb0e0f..5cbc93db 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -129,11 +129,15 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg txt_file = os.path.join(root_path, meta_file) items = [] with open(txt_file, "r", encoding="utf-8") as ttf: + speaker_id = 0 for idx, line in enumerate(ttf): + # 2 samples per speaker to avoid eval split issues + if idx%2 == 0: + speaker_id += 1 cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[2] - items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{idx}"}) + items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"}) return items diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 19c2e8f7..d210995d 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -8,8 +8,8 @@ from torch.utils.data import DataLoader from tests import get_tests_output_path from TTS.tts.configs.shared_configs import BaseTTSConfig -from TTS.tts.datasets import TTSDataset -from TTS.tts.datasets.formatters import ljspeech +from TTS.tts.datasets import TTSDataset, load_tts_samples +from TTS.config.shared_configs import BaseDatasetConfig from TTS.utils.audio import AudioProcessor # pylint: disable=unused-variable @@ -18,11 +18,19 @@ OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/") os.makedirs(OUTPATH, exist_ok=True) # create a dummy config for testing data loaders. 
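For reference, the eval-split resolution that split_dataset implements above can be exercised in isolation. A minimal sketch of the same arithmetic (the helper name is hypothetical, not part of the patch):

def resolve_eval_split_size(n_items, eval_split_size=0.01, eval_split_max_size=None):
    # A value > 1 is an absolute sample count; otherwise it is a proportion,
    # optionally capped by eval_split_max_size, mirroring split_dataset.
    if eval_split_size > 1:
        return int(eval_split_size)
    if eval_split_max_size:
        return min(eval_split_max_size, int(n_items * eval_split_size))
    return int(n_items * eval_split_size)

assert resolve_eval_split_size(1000) == 10  # default 1% proportion
assert resolve_eval_split_size(1000, eval_split_size=50) == 50  # absolute count
assert resolve_eval_split_size(100000, eval_split_max_size=500) == 500  # capped proportion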
-c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2) +c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2, use_noise_augment=False) c.r = 5 c.data_path = "tests/data/ljspeech/" ok_ljspeech = os.path.exists(c.data_path) +dataset_config = BaseDatasetConfig( + name="ljspeech_test", # ljspeech_test to multi-speaker + meta_file_train="metadata.csv", + meta_file_val=None, + path=c.data_path, + language="en", +) + DATA_EXIST = True if not os.path.exists(c.data_path): DATA_EXIST = False @@ -37,11 +45,10 @@ class TestTTSDataset(unittest.TestCase): self.ap = AudioProcessor(**c.audio) def _create_dataloader(self, batch_size, r, bgs): - items = ljspeech(c.data_path, "metadata.csv") - # add a default language because now the TTSDataset expect a language - language = "" - items = [[*item, language] for item in items] + # load dataset + meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2) + items = meta_data_train + meta_data_eval dataset = TTSDataset( r, @@ -97,8 +104,12 @@ class TestTTSDataset(unittest.TestCase): # make sure that the computed mels and the waveform match and correctly computed mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy()) + # remove padding in mel-spectrogram + mel_dataloader = mel_input[0].T.numpy()[:, :mel_lengths[0]] + # guarantee that both mel-spectrograms have the same size and that we will remove waveform padding + mel_new = mel_new[:, :mel_lengths[0]] ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length) - mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg] + mel_diff = (mel_new - mel_dataloader)[:, 0:ignore_seg] assert abs(mel_diff.sum()) < 1e-5 # check normalization ranges diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 3d8d6c75..497a3fb5 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -39,7 +39,7 @@ random_sampler = torch.utils.data.RandomSampler(train_samples) ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)]) en, pt = 0, 0 for index in ids: - if train_samples[index][3] == "en": + if train_samples[index]["language"] == "en": en += 1 else: pt += 1 @@ -50,7 +50,7 @@ weighted_sampler = get_language_weighted_sampler(train_samples) ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) en, pt = 0, 0 for index in ids: - if train_samples[index][3] == "en": + if train_samples[index]["language"] == "en": en += 1 else: pt += 1 From 89dd89b5e500cba2257a32c35870b8951584b5b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 22 Feb 2022 12:18:03 +0100 Subject: [PATCH 017/214] Update LJSpeech DCA recipe --- recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py index cf00ccc2..8c159d98 100644 --- a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py +++ b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py @@ -38,10 +38,10 @@ config = Tacotron2Config( # This is the config that is saved for the future use num_eval_loader_workers=4, run_eval=True, test_delay_epochs=-1, - ga_alpha=5.0, + ga_alpha=0.0, r=2, attention_type="dynamic_convolution", - double_decoder_consistency=True, + double_decoder_consistency=False, epochs=1000, text_cleaner="phoneme_cleaners", 
use_phonemes=True, From 4f68ba7127c0d77fbdd33b0701dfcaa164d364b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 22 Feb 2022 13:54:53 +0100 Subject: [PATCH 018/214] Disable extra losses --- recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py index 8c159d98..0a285c3b 100644 --- a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py +++ b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py @@ -39,6 +39,12 @@ config = Tacotron2Config( # This is the config that is saved for the future use run_eval=True, test_delay_epochs=-1, ga_alpha=0.0, + decoder_loss_alpha=0.25, + postnet_loss_alpha=0.25, + postnet_diff_spec_alpha=0, + decoder_diff_spec_alpha=0, + decoder_ssim_alpha=0, + postnet_ssim_alpha=0, r=2, attention_type="dynamic_convolution", double_decoder_consistency=False, From ca02b82218ed291db4fa5dbb22b8486ba9024531 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:23:02 +0100 Subject: [PATCH 019/214] Implement ZH_CH phonemizer --- .../text/phonemizers/zh_cn_phonemizer.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py diff --git a/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py new file mode 100644 index 00000000..e1bd77c7 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py @@ -0,0 +1,61 @@ +from typing import Dict + +from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class ZH_CN_Phonemizer(BasePhonemizer): + """🐸TTS Zh-Cn phonemizer using functions in `TTS.tts.utils.text.chinese_mandarin.phonemizer` + + Args: + punctuations (str): + Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. 
+
+    Example ::
+
+        "这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| |,| |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |。`
+
+    TODO: someone with Mandarin knowledge should check this implementation
+    """
+
+    language = "zh-cn"
+
+    def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs):
+        super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)
+
+    @staticmethod
+    def name():
+        return "zh_cn_phonemizer"
+
+    def phonemize_zh_cn(self, text: str, separator: str = "|") -> str:
+        ph = chinese_text_to_phonemes(text, separator)
+        return ph
+
+    def _phonemize(self, text, separator):
+        return self.phonemize_zh_cn(text, separator)
+
+    @staticmethod
+    def supported_languages() -> Dict:
+        return {"zh-cn": "Chinese (China)"}
+
+    def version(self) -> str:
+        return "0.0.1"
+
+    def is_available(self) -> bool:
+        return True
+
+
+if __name__ == "__main__":
+    text = "这是,样本中文。"
+    e = ZH_CN_Phonemizer()
+    print(e.supported_languages())
+    print(e.version())
+    print(e.language)
+    print(e.name())
+    print(e.is_available())
+    print("`" + e.phonemize(text) + "`")

From 172ba0c5e7d402512627036b999a628f0051e0fc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:23:17 +0100
Subject: [PATCH 020/214] Implement JA_JP phonemizer

---
 .../text/phonemizers/ja_jp_phonemizer.py | 52 +++++++++++++++++++
 1 file changed, 52 insertions(+)
 create mode 100644 TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py

diff --git a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py
new file mode 100644
index 00000000..fcd170ba
--- /dev/null
+++ b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py
@@ -0,0 +1,52 @@
+from typing import Dict
+
+from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
+from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
+
+_DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】"
+
+
+class JA_JP_Phonemizer(BasePhonemizer):
+    """🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer`
+
+    TODO: someone with JA knowledge should check this implementation
+    """
+
+    language = "ja-jp"
+
+    def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=False, **kwargs):
+        super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)
+
+    @staticmethod
+    def name():
+        return "ja_jp_phonemizer"
+
+    def phonemize_jajp(self, text: str, separator: str = "|") -> str:
+        ph = japanese_text_to_phonemes(text)
+        if separator is not None and separator != "":
+            return separator.join(ph)
+        return ph
+
+    def _phonemize(self, text, separator):
+        return self.phonemize_jajp(text, separator)
+
+    @staticmethod
+    def supported_languages() -> Dict:
+        return {"ja-jp": "Japanese (Japan)"}
+
+    def version(self) -> str:
+        return "0.0.1"
+
+    def is_available(self) -> bool:
+        return True
+
+
+if __name__ == "__main__":
+    text = "これは、電話をかけるための私の日本語の例のテキストです。"
+    e = JA_JP_Phonemizer()
+    print(e.supported_languages())
+    print(e.version())
+    print(e.language)
+    print(e.name())
+    print(e.is_available())
+    print("`" + e.phonemize(text) + "`")

From e03a05c8160974471cfe002916c843bdfa30414e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:23:44 +0100
Subject: [PATCH 021/214] Implement gruut wrapper

---
 .../utils/text/phonemizers/gruut_wrapper.py | 145 ++++++++++++++++++
 1 file changed, 145 insertions(+)
 create mode 100644 TTS/tts/utils/text/phonemizers/gruut_wrapper.py

diff --git a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py
b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py new file mode 100644 index 00000000..a1ad1b80 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py @@ -0,0 +1,145 @@ +import importlib +from os import stat +from typing import List + +import gruut +from gruut_ipa import IPA + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + +# Table for str.translate to fix gruut/TTS phoneme mismatch +GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ") + + +class Gruut(BasePhonemizer): + """Gruut wrapper for G2P + + Args: + language (str): + Valid language code for the used backend. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`. + + keep_puncs (bool): + If true, keep the punctuations after phonemization. Defaults to True. + + use_espeak_phonemes (bool): + If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False. + + keep_stress (bool): + If true, keep the stress characters after phonemization. Defaults to False. + """ + + def __init__( + self, + language: str, + punctuations=Punctuation.default_puncs(), + keep_puncs=True, + use_espeak_phonemes=False, + keep_stress=False, + ): + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + self.use_espeak_phonemes = use_espeak_phonemes + self.keep_stress = keep_stress + + @staticmethod + def name(): + return "gruut" + + def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: + """Convert input text to phonemes. + + Gruut phonemizes the given `str` by seperating each phoneme character with `separator`, even for characters + that constitude a single sound. + + It doesn't affect 🐸TTS since it individually converts each character to token IDs. + + Examples:: + "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ` + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. + """ + ph_list = [] + for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes): + for word in sentence: + if word.is_break: + # Use actual character for break phoneme (e.g., comma) + if ph_list: + # Join with previous word + ph_list[-1].append(word.text) + else: + # First word is punctuation + ph_list.append([word.text]) + elif word.phonemes: + # Add phonemes for word + word_phonemes = [] + + for word_phoneme in word.phonemes: + if not self.keep_stress: + # Remove primary/secondary stress + word_phoneme = IPA.without_stress(word_phoneme) + + word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE) + + if word_phoneme: + # Flatten phonemes + word_phonemes.extend(word_phoneme) + + if word_phonemes: + ph_list.append(word_phonemes) + + ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list] + ph = f"{separator} ".join(ph_words) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_gruut(text, separator, tie=False) + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return gruut.is_language_supported(language) + + @staticmethod + def supported_languages() -> List: + """Get a dictionary of supported languages. + + Returns: + List: List of language codes. 
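As a usage illustration for the wrapper above, a sketch assuming the `gruut` package and its en-us data are installed (the printed phonemes are indicative only):

from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut

phonemizer = Gruut(language="en-us", keep_puncs=True)
# Phoneme characters within a word are joined by `separator`; words are
# joined by "<separator> ", as described in phonemize_gruut above.
print(phonemizer.phonemize("Be a voice, not an echo.", separator="|"))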
+ """ + return list(gruut.get_supported_languages()) + + def version(self): + """Get the version of the used backend. + + Returns: + str: Version of the used backend. + """ + return gruut.__version__ + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return importlib.util.find_spec("gruut") is not None + + +if __name__ == "__main__": + e = Gruut(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = Gruut(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = Gruut(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how, are you today?") + "`") From 5e4f78add387f0130b971a8f0de597566711f3ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:24:06 +0100 Subject: [PATCH 022/214] Implement espeak wrapper --- .../utils/text/phonemizers/espeak_wrapper.py | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 TTS/tts/utils/text/phonemizers/espeak_wrapper.py diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py new file mode 100644 index 00000000..59c4a8ee --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -0,0 +1,173 @@ +import logging +import subprocess +import tempfile +from typing import Dict, List + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + + +def is_tool(name): + from shutil import which + + return which(name) is not None + + +if is_tool("espeak-ng"): + _DEF_ESPEAK_LIB = "espeak-ng" +elif is_tool("espeak"): + _DEF_ESPEAK_LIB = "espeak" +else: + _DEF_ESPEAK_LIB = None + + +def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: + cmd = [ + espeak_lib, + "-b", + "1", # UTF8 text encoding + ] + cmd.extend(args) + logging.debug("espeakng: executing %s" % repr(cmd)) + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + res = iter(p.stdout.readline, b"") + if not sync: + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + return res + res2 = [] + for line in res: + res2.append(line) + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + p.wait() + return res2 + + +class ESpeak(BasePhonemizer): + """ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line the perform G2P + + Args: + language (str): + Valid language code for the used backend. + + backend (str): + Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically + prefering `espeak-ng` over `espeak`. Defaults to None. + + punctuations (str): + Characters to be treated as punctuation. Defaults to Punctuation.default_puncs(). + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to True. + """ + + _ESPEAK_LIB = _DEF_ESPEAK_LIB + + def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True): + if self._ESPEAK_LIB is None: + raise Exception("Unknown backend: %s" % backend) + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + + def auto_set_espeak_lib(self) -> None: + if is_tool("espeak-ng"): + self._ESPEAK_LIB = "espeak-ng" + elif is_tool("espeak"): + self._ESPEAK_LIB = "espeak" + else: + raise Exception("Cannot set backend automatically. 
espeak-ng or espeak not found") + + @staticmethod + def name(): + return "espeak" + + def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str: + """Convert input text to phonemes. + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. + """ + # set arguments + args = ["-q", "-v", f"{self._language}"] + if tie: + args.append("--ipa=1") # use '͡' between phonemes + else: + args.append("--ipa=3") # split with '_' + if tie: + args.append("--tie=%s" % tie) + args.append(text) + # compute phonemes + phonemes = "" + for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): + logging.debug("line: %s" % repr(line)) + phonemes += line.decode("utf8").strip() + return phonemes.replace("_", separator) + + def _phonemize(self, text, separator=None): + return self.phonemize_espeak(text, separator, tie=False) + + @staticmethod + def supported_languages() -> Dict: + """Get a dictionary of supported languages. + + Returns: + Dict: Dictionary of language codes. + """ + if _DEF_ESPEAK_LIB is None: + raise {} + args = ["--voices"] + langs = {} + count = 0 + for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True): + line = line.decode("utf8").strip() + if count > 0: + cols = line.split() + lang_code = cols[1] + lang_name = cols[3] + langs[lang_code] = lang_name + logging.debug("line: %s" % repr(line)) + count += 1 + return langs + + def version(self): + """Get the version of the used backend. + + Returns: + str: Version of the used backend. + """ + args = ["--version"] + for line in self._espeak_exe(args, sync=True): + version = line.decode("utf8").strip().split()[2] + logging.debug("line: %s" % repr(line)) + return version + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return is_tool("espeak") or is_tool("espeak-ng") + + +if __name__ == "__main__": + e = ESpeak(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = ESpeak(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = ESpeak(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how are you today?") + "`") From 80867c8e8c0f07675a70306fa554ac3dd6eb748e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:24:26 +0100 Subject: [PATCH 023/214] Implement multi-phonemizer --- .../text/phonemizers/multi_phonemizer.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 TTS/tts/utils/text/phonemizers/multi_phonemizer.py diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py new file mode 100644 index 00000000..e8b2ce34 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -0,0 +1,55 @@ +from typing import Dict, List + +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name + + +class MultiPhonemizer: + """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages + + Args: + custom_lang_to_phonemizer (Dict): + Custom phonemizer mapping if you want to change the defaults. In the format of + `{"lang_code", "phonemizer_name"}`. When it is None, `DEF_LANG_TO_PHONEMIZER` is used. Defaults to `{}`. 
+ + TODO: find a way to pass custom kwargs to the phonemizers + """ + + lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER + language = "multi-lingual" + + def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: + self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer) + self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name) + + @staticmethod + def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict: + lang_to_phonemizer = {} + for k, v in lang_to_phonemizer_name.items(): + phonemizer = get_phonemizer_by_name(v, language=k) + lang_to_phonemizer[k] = phonemizer + return lang_to_phonemizer + + @staticmethod + def name(): + return "multi-phonemizer" + + def phonemize(self, text, language, separator="|"): + return self.lang_to_phonemizer[language].phonemize(text, separator) + + def supported_languages(self) -> List: + return list(self.lang_to_phonemizer_name.keys()) + + +if __name__ == "__main__": + texts = { + "tr": "Merhaba, bu Türkçe bit örnek!", + "en-us": "Hello, this is English example!", + "de": "Hallo, das ist ein Deutches Beipiel!", + "zh-cn": "这是中国的例子", + } + phonemes = {} + ph = MultiPhonemizer() + for lang, text in texts.items(): + phoneme = ph.phonemize(text, lang) + phonemes[lang] = phoneme + print(phonemes) From dcd01356e05df4bfad43dbf698fd3c9f5d246bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:25:10 +0100 Subject: [PATCH 024/214] Create `text/english` folder --- TTS/tts/utils/text/english/__init__.py | 0 TTS/tts/utils/text/english/abbreviations.py | 26 ++++++ TTS/tts/utils/text/english/number_norm.py | 97 +++++++++++++++++++++ TTS/tts/utils/text/english/time_norm.py | 47 ++++++++++ 4 files changed, 170 insertions(+) create mode 100644 TTS/tts/utils/text/english/__init__.py create mode 100644 TTS/tts/utils/text/english/abbreviations.py create mode 100644 TTS/tts/utils/text/english/number_norm.py create mode 100644 TTS/tts/utils/text/english/time_norm.py diff --git a/TTS/tts/utils/text/english/__init__.py b/TTS/tts/utils/text/english/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/utils/text/english/abbreviations.py b/TTS/tts/utils/text/english/abbreviations.py new file mode 100644 index 00000000..cd93c13c --- /dev/null +++ b/TTS/tts/utils/text/english/abbreviations.py @@ -0,0 +1,26 @@ +import re + +# List of (regular expression, replacement) pairs for abbreviations in english: +abbreviations_en = [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] diff --git a/TTS/tts/utils/text/english/number_norm.py b/TTS/tts/utils/text/english/number_norm.py new file mode 100644 index 00000000..e8377ede --- /dev/null +++ b/TTS/tts/utils/text/english/number_norm.py @@ -0,0 +1,97 @@ +""" from https://github.com/keithito/tacotron """ + +import re +from typing import Dict + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"-?[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def __expand_currency(value: str, inflection: Dict[float, str]) -> str: + parts = value.replace(",", "").split(".") + if len(parts) > 2: + return f"{value} {inflection[2]}" # Unexpected format + text = [] + integer = int(parts[0]) if parts[0] else 0 + if integer > 0: + integer_unit = inflection.get(integer, inflection[2]) + text.append(f"{integer} {integer_unit}") + fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if fraction > 0: + fraction_unit = inflection.get(fraction / 100, inflection[0.02]) + text.append(f"{fraction} {fraction_unit}") + if len(text) == 0: + return f"zero {inflection[2]}" + return " ".join(text) + + +def _expand_currency(m: "re.Match") -> str: + currencies = { + "$": { + 0.01: "cent", + 0.02: "cents", + 1: "dollar", + 2: "dollars", + }, + "€": { + 0.01: "cent", + 0.02: "cents", + 1: "euro", + 2: "euros", + }, + "£": { + 0.01: "penny", + 0.02: "pence", + 1: "pound sterling", + 2: "pounds sterling", + }, + "¥": { + # TODO rin + 0.02: "sen", + 2: "yen", + }, + } + unit = m.group(1) + currency = currencies[unit] + value = m.group(2) + return __expand_currency(value, currency) + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if 1000 < num < 3000: + if num == 2000: + return "two thousand" + if 2000 < num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + if num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_currency_re, _expand_currency, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/TTS/tts/utils/text/english/time_norm.py b/TTS/tts/utils/text/english/time_norm.py new file mode 100644 index 00000000..c8ac09e7 --- /dev/null +++ b/TTS/tts/utils/text/english/time_norm.py @@ -0,0 +1,47 @@ +import re + +import inflect + +_inflect = inflect.engine() + +_time_re = re.compile( + r"""\b + ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours + : + ([0-5][0-9]) # minutes + 
\s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? # am/pm + \b""", + re.IGNORECASE | re.X, +) + + +def _expand_num(n: int) -> str: + return _inflect.number_to_words(n) + + +def _expand_time_english(match: "re.Match") -> str: + hour = int(match.group(1)) + past_noon = hour >= 12 + time = [] + if hour > 12: + hour -= 12 + elif hour == 0: + hour = 12 + past_noon = True + time.append(_expand_num(hour)) + + minute = int(match.group(6)) + if minute > 0: + if minute < 10: + time.append("oh") + time.append(_expand_num(minute)) + am_pm = match.group(7) + if am_pm is None: + time.append("p m" if past_noon else "a m") + else: + time.extend(list(am_pm.replace(".", ""))) + return " ".join(time) + + +def expand_time_english(text: str) -> str: + return re.sub(_time_re, _expand_time_english, text) From c1119bc29159f48fd91d986984772b2bac9fc9cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:25:30 +0100 Subject: [PATCH 025/214] Implement BasePhonemizer --- TTS/tts/utils/text/phonemizers/__init__.py | 51 ++++++++ TTS/tts/utils/text/phonemizers/base.py | 136 +++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 TTS/tts/utils/text/phonemizers/__init__.py create mode 100644 TTS/tts/utils/text/phonemizers/base.py diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py new file mode 100644 index 00000000..c0ef7909 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -0,0 +1,51 @@ +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak +from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut +from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer +from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer + +PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, JA_JP_Phonemizer)} + + +ESPEAK_LANGS = list(ESpeak.supported_languages().keys()) +GRUUT_LANGS = list(Gruut.supported_languages()) + + +# Dict setting default phonemizers for each language +DEF_LANG_TO_PHONEMIZER = { + "ja-jp": JA_JP_Phonemizer.name(), + "zh-cn": ZH_CN_Phonemizer.name(), +} + + +# Add Gruut languages +_ = [Gruut.name()] * len(GRUUT_LANGS) +_new_dict = dict(list(zip(GRUUT_LANGS, _))) +DEF_LANG_TO_PHONEMIZER.update(_new_dict) + + +# Add ESpeak languages and override any existing ones +_ = [ESpeak.name()] * len(ESPEAK_LANGS) +_new_dict = dict(list(zip(list(ESPEAK_LANGS), _))) +DEF_LANG_TO_PHONEMIZER.update(_new_dict) + + +def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: + """Initiate a phonemizer by name + + Args: + name (str): + Name of the phonemizer that should match `phonemizer.name()`. + + kwargs (dict): + Extra keyword arguments that should be passed to the phonemizer. 
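The layering of DEF_LANG_TO_PHONEMIZER above can be replayed with plain dictionaries. A toy sketch (the language lists here are illustrative, not the real backend inventories):

defaults = {"ja-jp": "ja_jp_phonemizer", "zh-cn": "zh_cn_phonemizer"}
gruut_langs = ["en-us", "de", "fr-fr"]
espeak_langs = ["en-us", "de", "tr"]
# Gruut registers its languages first...
defaults.update(dict(zip(gruut_langs, ["gruut"] * len(gruut_langs))))
# ...then ESpeak overrides any overlapping entries, as in the module above.
defaults.update(dict(zip(espeak_langs, ["espeak"] * len(espeak_langs))))
assert defaults["en-us"] == "espeak"  # ESpeak wins on overlap
assert defaults["fr-fr"] == "gruut"   # Gruut keeps non-overlapping languages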
+ """ + if name == "espeak": + return ESpeak(**kwargs) + if name == "gruut": + return Gruut(**kwargs) + if name == "zh_cn_phonemizer": + return ZH_CN_Phonemizer(**kwargs) + if name == "ja_jp_phonemizer": + return JA_JP_Phonemizer(**kwargs) + raise ValueError(f"Phonemizer {name} not found") diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py new file mode 100644 index 00000000..b370822c --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -0,0 +1,136 @@ +import abc +import itertools +from typing import List, Tuple, Union + +from TTS.tts.utils.text.punctuation import Punctuation + + +class BasePhonemizer(abc.ABC): + """Base phonemizer class + + Args: + language (str): + Language used by the phonemizer. + + punctuations (List[str]): + List of punctuation marks to be preserved. + + keep_puncs (bool): + Whether to preserve punctuation marks or not. + """ + + def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False): + + # ensure the backend is installed on the system + if not self.is_available(): + raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover + + # ensure the backend support the requested language + self._language = self._init_language(language) + + # setup punctuation processing + self._keep_puncs = keep_puncs + self._punctuator = Punctuation(punctuations) + + + def _init_language(self, language): + """Language initialization + + This method may be overloaded in child classes (see Segments backend) + + """ + if not self.is_supported_language(language): + raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend") + return language + + @property + def language(self): + """The language code configured to be used for phonemization""" + return self._language + + @staticmethod + @abc.abstractmethod + def name(): + """The name of the backend""" + + @classmethod + @abc.abstractmethod + def is_available(cls): + """Returns True if the backend is installed, False otherwise""" + + @classmethod + @abc.abstractmethod + def version(cls): + """Return the backend version as a tuple (major, minor, patch)""" + + @abc.abstractmethod + def supported_languages(): + """Return a dict of language codes -> name supported by the backend""" + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return language in self.supported_languages() + + fr""" + Phonemization follows the following steps: + 1. Preprocessing: + - remove empty lines + - remove punctuation + - keep track of punctuation marks + + 2. Phonemization: + - convert text to phonemes + + 3. 
Postprocessing:
+            - join phonemes
+            - restore punctuation marks
+    """
+
+    @abc.abstractmethod
+    def _phonemize(self, text, separator):
+        """The main phonemization method"""
+
+    def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
+        """Preprocess the text before phonemization
+
+        Override this if you need a different behaviour
+        """
+        if self._keep_puncs:
+            # a tuple (text, punctuation marks)
+            return self._punctuator.strip_to_restore(text)
+        return [self._punctuator.strip(text)], []
+
+    def _phonemize_postprocess(self, phonemized, punctuations) -> str:
+        """Postprocess the raw phonemized output
+
+        Override this if you need a different behaviour
+        """
+        if self._keep_puncs:
+            return self._punctuator.restore(phonemized, punctuations)[0]
+        return phonemized[0]
+
+    def phonemize(self, text: str, separator="|") -> str:
+        """Returns the `text` phonemized for the given language
+
+        Args:
+            text (str):
+                Text to be phonemized.
+
+            separator (str):
+                string separator used between phonemes. Defaults to '|'.
+
+        Returns:
+            (str): Phonemized text
+        """
+        text, punctuations = self._phonemize_preprocess(text)
+        phonemized = []
+        for t in text:
+            p = self._phonemize(t, separator)
+            phonemized.append(p)
+        phonemized = self._phonemize_postprocess(phonemized, punctuations)
+        return phonemized
+
+    def print_logs(self, level: int = 0):
+        indent = "\t" * level
+        print(f"{indent}| > phoneme language: {self.language}")
+        print(f"{indent}| > phoneme backend: {self.name()}")

From 1bee40af407503c6d51cac40af54dc192ceeb496 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:26:22 +0100
Subject: [PATCH 026/214] Create language folders under `TTS.tts.utils.text`

---
 .../utils/text/chinese_mandarin/phonemizer.py |  4 +-
 TTS/tts/utils/text/french/__init__.py | 0
 .../utils/text/{ => french}/abbreviations.py | 25 -----
 TTS/tts/utils/text/number_norm.py | 97 -------------------
 TTS/tts/utils/text/time.py | 47 ---------
 5 files changed, 2 insertions(+), 171 deletions(-)
 create mode 100644 TTS/tts/utils/text/french/__init__.py
 rename TTS/tts/utils/text/{ => french}/abbreviations.py (66%)
 delete mode 100644 TTS/tts/utils/text/number_norm.py
 delete mode 100644 TTS/tts/utils/text/time.py

diff --git a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py
index 29cac160..727c881e 100644
--- a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py
+++ b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py
@@ -19,7 +19,7 @@ def _chinese_pinyin_to_phoneme(pinyin: str) -> str:
     return phoneme + tone
 
 
-def chinese_text_to_phonemes(text: str) -> str:
+def chinese_text_to_phonemes(text: str, separator: str = "|") -> str:
     tokenized_text = jieba.cut(text, HMM=False)
     tokenized_text = " ".join(tokenized_text)
     pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text)
@@ -34,4 +34,4 @@ def chinese_text_to_phonemes(text: str) -> str:
         else:  # is ponctuation or other
             results += list(token)
 
-    return "|".join(results)
+    return separator.join(results)

diff --git a/TTS/tts/utils/text/french/__init__.py b/TTS/tts/utils/text/french/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/TTS/tts/utils/text/abbreviations.py b/TTS/tts/utils/text/french/abbreviations.py
similarity index 66%
rename from TTS/tts/utils/text/abbreviations.py
rename to TTS/tts/utils/text/french/abbreviations.py
index 7e44b90c..f580dfed 100644
--- a/TTS/tts/utils/text/abbreviations.py
+++ b/TTS/tts/utils/text/french/abbreviations.py
@@ -1,30 +1,5 @@
 import re
 
-# List of
(regular expression, replacement) pairs for abbreviations in english: -abbreviations_en = [ - (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) - for x in [ - ("mrs", "misess"), - ("mr", "mister"), - ("dr", "doctor"), - ("st", "saint"), - ("co", "company"), - ("jr", "junior"), - ("maj", "major"), - ("gen", "general"), - ("drs", "doctors"), - ("rev", "reverend"), - ("lt", "lieutenant"), - ("hon", "honorable"), - ("sgt", "sergeant"), - ("capt", "captain"), - ("esq", "esquire"), - ("ltd", "limited"), - ("col", "colonel"), - ("ft", "fort"), - ] -] - # List of (regular expression, replacement) pairs for abbreviations in french: abbreviations_fr = [ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) diff --git a/TTS/tts/utils/text/number_norm.py b/TTS/tts/utils/text/number_norm.py deleted file mode 100644 index e8377ede..00000000 --- a/TTS/tts/utils/text/number_norm.py +++ /dev/null @@ -1,97 +0,0 @@ -""" from https://github.com/keithito/tacotron """ - -import re -from typing import Dict - -import inflect - -_inflect = inflect.engine() -_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") -_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") -_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)") -_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") -_number_re = re.compile(r"-?[0-9]+") - - -def _remove_commas(m): - return m.group(1).replace(",", "") - - -def _expand_decimal_point(m): - return m.group(1).replace(".", " point ") - - -def __expand_currency(value: str, inflection: Dict[float, str]) -> str: - parts = value.replace(",", "").split(".") - if len(parts) > 2: - return f"{value} {inflection[2]}" # Unexpected format - text = [] - integer = int(parts[0]) if parts[0] else 0 - if integer > 0: - integer_unit = inflection.get(integer, inflection[2]) - text.append(f"{integer} {integer_unit}") - fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if fraction > 0: - fraction_unit = inflection.get(fraction / 100, inflection[0.02]) - text.append(f"{fraction} {fraction_unit}") - if len(text) == 0: - return f"zero {inflection[2]}" - return " ".join(text) - - -def _expand_currency(m: "re.Match") -> str: - currencies = { - "$": { - 0.01: "cent", - 0.02: "cents", - 1: "dollar", - 2: "dollars", - }, - "€": { - 0.01: "cent", - 0.02: "cents", - 1: "euro", - 2: "euros", - }, - "£": { - 0.01: "penny", - 0.02: "pence", - 1: "pound sterling", - 2: "pounds sterling", - }, - "¥": { - # TODO rin - 0.02: "sen", - 2: "yen", - }, - } - unit = m.group(1) - currency = currencies[unit] - value = m.group(2) - return __expand_currency(value, currency) - - -def _expand_ordinal(m): - return _inflect.number_to_words(m.group(0)) - - -def _expand_number(m): - num = int(m.group(0)) - if 1000 < num < 3000: - if num == 2000: - return "two thousand" - if 2000 < num < 2010: - return "two thousand " + _inflect.number_to_words(num % 100) - if num % 100 == 0: - return _inflect.number_to_words(num // 100) + " hundred" - return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") - return _inflect.number_to_words(num, andword="") - - -def normalize_numbers(text): - text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_currency_re, _expand_currency, text) - text = re.sub(_decimal_number_re, _expand_decimal_point, text) - text = re.sub(_ordinal_re, _expand_ordinal, text) - text = re.sub(_number_re, _expand_number, text) - return text diff --git a/TTS/tts/utils/text/time.py b/TTS/tts/utils/text/time.py deleted file mode 100644 index c8ac09e7..00000000 --- 
a/TTS/tts/utils/text/time.py +++ /dev/null @@ -1,47 +0,0 @@ -import re - -import inflect - -_inflect = inflect.engine() - -_time_re = re.compile( - r"""\b - ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours - : - ([0-5][0-9]) # minutes - \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? # am/pm - \b""", - re.IGNORECASE | re.X, -) - - -def _expand_num(n: int) -> str: - return _inflect.number_to_words(n) - - -def _expand_time_english(match: "re.Match") -> str: - hour = int(match.group(1)) - past_noon = hour >= 12 - time = [] - if hour > 12: - hour -= 12 - elif hour == 0: - hour = 12 - past_noon = True - time.append(_expand_num(hour)) - - minute = int(match.group(6)) - if minute > 0: - if minute < 10: - time.append("oh") - time.append(_expand_num(minute)) - am_pm = match.group(7) - if am_pm is None: - time.append("p m" if past_noon else "a m") - else: - time.extend(list(am_pm.replace(".", ""))) - return " ".join(time) - - -def expand_time_english(text: str) -> str: - return re.sub(_time_re, _expand_time_english, text) From 2fb1f705031d4a9602e5853232d28b53cde89a5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:27:25 +0100 Subject: [PATCH 027/214] Implement BaseCharacters, IPAPhonemes, Graphemes --- TTS/tts/utils/text/symbols.py | 289 +++++++++++++++++++++++++++++++--- 1 file changed, 265 insertions(+), 24 deletions(-) diff --git a/TTS/tts/utils/text/symbols.py b/TTS/tts/utils/text/symbols.py index cb708958..ce59031d 100644 --- a/TTS/tts/utils/text/symbols.py +++ b/TTS/tts/utils/text/symbols.py @@ -7,21 +7,33 @@ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA """ +def parse_symbols(): + return { + "pad": _pad, + "eos": _eos, + "bos": _bos, + "characters": _characters, + "punctuations": _punctuations, + "phonemes": _phonemes, + } + + def make_symbols( characters, phonemes=None, punctuations="!'(),-.:;? ", - pad="_", - eos="~", - bos="^", + pad="", + eos="", + bos="", + blank="", unique=True, ): # pylint: disable=redefined-outer-name - """Function to create symbols and phonemes - TODO: create phonemes_to_id and symbols_to_id dicts here.""" + """Function to create default characters and phonemes""" _symbols = list(characters) _symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols _symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols _symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols + _symbols = [blank] + _symbols if len(bos) > 0 and blank is not None else _symbols _phonemes = None if phonemes is not None: _phonemes_sorted = ( @@ -35,9 +47,10 @@ def make_symbols( return _symbols, _phonemes -_pad = "_" -_eos = "~" -_bos = "^" +_pad = "" +_eos = "" +_bos = "" +_blank = "" # TODO: check if we need this alongside with PAD _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? " _punctuations = "!'(),-.:;? " @@ -52,24 +65,252 @@ _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprase symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos) -# Generate ALIEN language -# from random import shuffle -# shuffle(phonemes) + +class BaseCharacters: + """🐸BaseCharacters class + + Every vocabulary class should inherit from this class. + + Args: + characters (str): + Main set of characters to be used in the vocabulary. + + punctuations (str): + Characters to be treated as punctuation. + + pad (str): + Special padding character that would be ignored by the model. 
+
+        eos (str):
+            End of the sentence character.
+
+        bos (str):
+            Beginning of the sentence character.
+
+        blank (str):
+            Optional character used between characters by some models for better prosody.
+
+        is_unique (bool):
+            Remove duplicates from the provided characters. Defaults to True.
+
+        is_sorted (bool):
+            Sort the characters in alphabetical order. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        characters: str,
+        punctuations: str,
+        pad: str,
+        eos: str,
+        bos: str,
+        blank: str,
+        is_unique: bool = True,
+        is_sorted: bool = True,
+    ) -> None:
+        self._characters = characters
+        self._punctuations = punctuations
+        self._pad = pad
+        self._eos = eos
+        self._bos = bos
+        self._blank = blank
+        self.is_unique = is_unique
+        self.is_sorted = is_sorted
+        self._create_vocab()
+
+    @property
+    def characters(self):
+        return self._characters
+
+    @characters.setter
+    def characters(self, characters):
+        self._characters = characters
+        self._create_vocab()
+
+    @property
+    def punctuations(self):
+        return self._punctuations
+
+    @punctuations.setter
+    def punctuations(self, punctuations):
+        self._punctuations = punctuations
+        self._create_vocab()
+
+    @property
+    def pad(self):
+        return self._pad
+
+    @pad.setter
+    def pad(self, pad):
+        self._pad = pad
+        self._create_vocab()
+
+    @property
+    def eos(self):
+        return self._eos
+
+    @eos.setter
+    def eos(self, eos):
+        self._eos = eos
+        self._create_vocab()
+
+    @property
+    def bos(self):
+        return self._bos
+
+    @bos.setter
+    def bos(self, bos):
+        self._bos = bos
+        self._create_vocab()
+
+    @property
+    def blank(self):
+        return self._blank
+
+    @blank.setter
+    def blank(self, blank):
+        self._blank = blank
+        self._create_vocab()
+
+    @property
+    def vocab(self):
+        return self._vocab
+
+    @property
+    def num_chars(self):
+        return len(self._vocab)
+
+    def _create_vocab(self):
+        _vocab = self.characters
+        if self.is_unique:
+            _vocab = list(set(_vocab))
+        if self.is_sorted:
+            _vocab = sorted(_vocab)
+        _vocab = list(_vocab)
+        _vocab = [self.blank] + _vocab if self.blank is not None and len(self.blank) > 0 else _vocab
+        _vocab = [self.bos] + _vocab if self.bos is not None and len(self.bos) > 0 else _vocab
+        _vocab = [self.eos] + _vocab if self.eos is not None and len(self.eos) > 0 else _vocab
+        _vocab = [self.pad] + _vocab if self.pad is not None and len(self.pad) > 0 else _vocab
+        self._vocab = _vocab + list(self._punctuations)
+        self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
+        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
+        assert len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
+
+    def char_to_id(self, char: str) -> int:
+        return self._char_to_id[char]
+
+    def id_to_char(self, idx: int) -> str:
+        return self._id_to_char[idx]
+
+    @staticmethod
+    def init_from_config(config: "Coqpit"):
+        return BaseCharacters(
+            **config.characters if config.characters is not None else {},
+        )
+
+
+class IPAPhonemes(BaseCharacters):
+    """🐸IPAPhonemes class to manage `TTS.tts` model vocabulary
+
+    Intended to be used with models using IPAPhonemes as input.
+    It uses system defaults for the undefined class arguments.
+
+    Args:
+        characters (str):
+            Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`.
+
+        punctuations (str):
+            Characters to be treated as punctuation. Defaults to `_punctuations`.
+
+        pad (str):
+            Special padding character that would be ignored by the model. Defaults to `_pad`.
+
+        eos (str):
+            End of the sentence character. Defaults to `_eos`.
+
+        bos (str):
+            Beginning of the sentence character. Defaults to `_bos`.
+
+        blank (str):
+            Optional character used between characters by some models for better prosody. Defaults to `_blank`.
+
+        is_unique (bool):
+            Remove duplicates from the provided characters. Defaults to True.
+
+        is_sorted (bool):
+            Sort the characters in alphabetical order. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        characters: str = _phonemes,
+        punctuations: str = _punctuations,
+        pad: str = _pad,
+        eos: str = _eos,
+        bos: str = _bos,
+        blank: str = _blank,
+        is_unique: bool = True,
+        is_sorted: bool = True,
+    ) -> None:
+        super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
+
+    @staticmethod
+    def init_from_config(config: "Coqpit"):
+        return IPAPhonemes(
+            **config.characters if config.characters is not None else {},
+        )
+
+
+class Graphemes(BaseCharacters):
+    """🐸Graphemes class to manage `TTS.tts` model vocabulary
+
+    Intended to be used with models using graphemes as input.
+    It uses system defaults for the undefined class arguments.
+
+    Args:
+        characters (str):
+            Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`.
+
+        punctuations (str):
+            Characters to be treated as punctuation. Defaults to `_punctuations`.
+
+        pad (str):
+            Special padding character that would be ignored by the model. Defaults to `_pad`.
+
+        eos (str):
+            End of the sentence character. Defaults to `_eos`.
+
+        bos (str):
+            Beginning of the sentence character. Defaults to `_bos`.
+
+        blank (str):
+            Optional character used between characters by some models for better prosody. Defaults to `_blank`.
+
+        is_unique (bool):
+            Remove duplicates from the provided characters. Defaults to True.
+
+        is_sorted (bool):
+            Sort the characters in alphabetical order. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        characters: str = _characters,
+        punctuations: str = _punctuations,
+        pad: str = _pad,
+        eos: str = _eos,
+        bos: str = _bos,
+        blank: str = _blank,
+        is_unique: bool = True,
+        is_sorted: bool = True,
+    ) -> None:
+        super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
+
+    @staticmethod
+    def init_from_config(config: "Coqpit"):
+        return Graphemes(
+            **config.characters if config.characters is not None else {},
+        )
+
 
 if __name__ == "__main__":
-    print(" > TTS symbols {}".format(len(symbols)))
-    print(symbols)
-    print(" > TTS phonemes {}".format(len(phonemes)))
-    print("".join(sorted(phonemes)))
+    gr = Graphemes()
+    ph = IPAPhonemes()
+
+    print(gr.vocab)
+    print(ph.vocab)
+
+    print(gr.num_chars)
+    assert "a" == gr.id_to_char(gr.char_to_id("a"))

From 0344645e90149f34164ea96c9d53f1007a3c8d6b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:27:42 +0100
Subject: [PATCH 028/214] Implement TTSTokenizer

---
 TTS/tts/utils/text/tokenizer.py | 120 ++++++++++++++++++++++++++++++++
 1 file changed, 120 insertions(+)
 create mode 100644 TTS/tts/utils/text/tokenizer.py

diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py
new file mode 100644
index 00000000..f6803ff6
--- /dev/null
+++ b/TTS/tts/utils/text/tokenizer.py
@@ -0,0 +1,120 @@
+from typing import Callable, Dict, List, Union
+
+from TTS.tts.utils.text import cleaners
+from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
+from TTS.tts.utils.text.symbols import Graphemes, IPAPhonemes
+
+
+class TTSTokenizer:
+    """🐸TTS tokenizer to convert input characters to token IDs and back.
+
+    Args:
+        use_phonemes (bool):
+            Whether to use phonemes instead of characters. Defaults to False.
+
+        characters (Characters):
+            A Characters object to use for character-to-ID and ID-to-character mappings.
+
+        text_cleaner (callable):
+            A function to pre-process the text before tokenization and phonemization. Defaults to None.
+
+        phonemizer (Phonemizer):
+            A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None.
+
+    """
+
+    def __init__(
+        self,
+        use_phonemes=False,
+        text_cleaner: Callable = None,
+        characters: "BaseCharacters" = None,
+        phonemizer: Union["Phonemizer", Dict] = None,
+        add_blank: bool = False,
+        use_eos_bos=False,
+    ):
+        self.text_cleaner = text_cleaner or (lambda x: x)
+        self.use_phonemes = use_phonemes
+        self.add_blank = add_blank
+        self.use_eos_bos = use_eos_bos
+        self.characters = characters
+        self.phonemizer = phonemizer
+
+    def encode(self, text: str) -> List[int]:
+        """Encodes a string of text as a sequence of IDs."""
+        token_ids = []
+        for char in text:
+            idx = self.characters.char_to_id(char)
+            token_ids.append(idx)
+        return token_ids
+
+    def decode(self, token_ids: List[int]) -> str:
+        """Decodes a sequence of IDs to a string of text."""
+        text = ""
+        for token_id in token_ids:
+            text += self.characters.id_to_char(token_id)
+        return text
+
+    def text_to_ids(self, text: str, language: str = None) -> List[int]:
+        """Converts a string of text to a sequence of token IDs.
+
+        Args:
+            text(str):
+                The text to convert to token IDs.
+
+            language(str):
+                The language code of the text. Defaults to None.
+
+        1. Text normalization
+        2. Phonemization (if use_phonemes is True)
+        3. Add blank char between characters
+        4. Add BOS and EOS characters
+        5. Text to token IDs
+        """
+        # TODO: text cleaner should pick the right routine based on the language
+        text = self.text_cleaner(text)
+        if self.use_phonemes:
+            text = self.phonemizer.phonemize(text, separator="")
+        if self.add_blank:
+            text = self.intersperse_blank_char(text, True)
+        if self.use_eos_bos:
+            text = self.pad_with_bos_eos(text)
+        return self.encode(text)
+
+    def ids_to_text(self, id_sequence: List[int]) -> str:
+        """Converts a sequence of token IDs to a string of text."""
+        return self.decode(id_sequence)
+
+    def pad_with_bos_eos(self, char_sequence: List[str]):
+        """Pads a sequence with the special BOS and EOS characters."""
+        return [self.characters.bos] + list(char_sequence) + [self.characters.eos]
+
+    def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
+        """Intersperses the blank character between characters in a sequence.
+
+        Uses the ``blank`` character if ``use_blank_char`` is set, else the ``pad`` character.
+        """
+        char_to_use = self.characters.blank if use_blank_char else self.characters.pad
+        result = [char_to_use] * (len(char_sequence) * 2 + 1)
+        result[1::2] = char_sequence
+        return result
+
+    def print_logs(self, level: int = 1):
+        indent = "\t" * level
+        print(f"{indent}| > add_blank: {self.add_blank}")
+        print(f"{indent}| > use_eos_bos: {self.use_eos_bos}")
+        print(f"{indent}| > use_phonemes: {self.use_phonemes}")
+        print(f"{indent}| > phonemizer: {self.phonemizer.print_logs(level + 1)}")
+
+    @staticmethod
+    def init_from_config(config: "Coqpit"):
+        """Init Tokenizer object from the config.
+
+        Args:
+            config (Coqpit): Coqpit model config.
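+
+        Example:
+            A minimal sketch for a grapheme-based setup (assumes ``use_phonemes=False``;
+            the config class mirrors the GlowTTS recipe later in this series)::
+
+                from TTS.tts.configs.glow_tts_config import GlowTTSConfig
+
+                config = GlowTTSConfig(text_cleaner="english_cleaners", use_phonemes=False)
+                tokenizer = TTSTokenizer.init_from_config(config)
+                token_ids = tokenizer.text_to_ids("Hello world!")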
+ """ + if isinstance(config.text_cleaner, (str, list)): + text_cleaner = getattr(cleaners, config.text_cleaner) + + if config.use_phonemes: + characters = IPAPhonemes().init_from_config(config) + phonemizer_kwargs = {"language": config.phoneme_language} + phonemizer = get_phonemizer_by_name(DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs) + else: + characters = Graphemes().init_from_config(config) + return TTSTokenizer(config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars) \ No newline at end of file From 1aca58afafb3eedeb9563955dea1e7de13cd7790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:28:15 +0100 Subject: [PATCH 029/214] Fix imports in cleaners.py --- TTS/tts/utils/text/cleaners.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index f3ffa478..0ff3e930 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -1,12 +1,16 @@ +"""Set of default text cleaners""" +# TODO: pick the cleaner for languages dynamically + import re from anyascii import anyascii from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text -from .abbreviations import abbreviations_en, abbreviations_fr -from .number_norm import normalize_numbers -from .time import expand_time_english +from .english.abbreviations import abbreviations_en +from .english.number_norm import normalize_numbers as en_normalize_numbers +from .english.time_norm import expand_time_english +from .french.abbreviations import abbreviations_fr # Regular expression matching whitespace: _whitespace_re = re.compile(r"\s+") @@ -22,10 +26,6 @@ def expand_abbreviations(text, lang="en"): return text -def expand_numbers(text): - return normalize_numbers(text) - - def lowercase(text): return text.lower() @@ -92,7 +92,7 @@ def english_cleaners(text): # text = convert_to_ascii(text) text = lowercase(text) text = expand_time_english(text) - text = expand_numbers(text) + text = en_normalize_numbers(text) text = expand_abbreviations(text) text = replace_symbols(text) text = remove_aux_symbols(text) @@ -128,7 +128,7 @@ def chinese_mandarin_cleaners(text: str) -> str: def phoneme_cleaners(text): """Pipeline for phonemes mode, including number and abbreviation expansion.""" - text = expand_numbers(text) + text = en_normalize_numbers(text) # text = convert_to_ascii(text) text = expand_abbreviations(text) text = replace_symbols(text) From 8d85af84cd5f1748f979fddcbc4aab1449f61ecb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:28:42 +0100 Subject: [PATCH 030/214] Implement Punctuation class --- TTS/tts/utils/text/punctuation.py | 156 ++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 TTS/tts/utils/text/punctuation.py diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py new file mode 100644 index 00000000..624cea88 --- /dev/null +++ b/TTS/tts/utils/text/punctuation.py @@ -0,0 +1,156 @@ +import collections +import re +from enum import Enum + +import six + +_DEF_PUNCS = ';:,.!?¡¿—…"«»“”' + +_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"]) + + +class PuncPosition(Enum): + """Enum for the punctuations positions""" + + BEGIN = 0 + END = 1 + MIDDLE = 2 + ALONE = 3 + + +class Punctuation: + """Handle punctuations characters in text. 
+ + Just strip punctuations from text or strip and restore them later. + + Args: + puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`. + """ + + def __init__(self, puncs: str = _DEF_PUNCS): + self.puncs = puncs + + @staticmethod + def default_puncs(): + """Return default set of punctuations.""" + return _DEF_PUNCS + + @property + def puncs(self): + return self._puncs + + @puncs.setter + def puncs(self, value): + if not isinstance(value, six.string_types): + raise ValueError("[!] Punctuations must be of type str.") + self._puncs = "".join(set(value)) + self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+") + + def strip(self, text): + """Remove all the punctuations by replacing with `space`. + + Args: + text (str): The text to be processed. + + Example:: + + "This is. example !" -> "This is example " + """ + return re.sub(self.puncs_regular_exp, " ", text).strip() + + def strip_to_restore(self, text): + """Remove punctuations from text to restore them later. + + Args: + text (str): The text to be processed. + + Examples :: + + "This is. example !" -> [["This is", "example"], [".", "!"]] + + """ + text, puncs = self._strip_to_restore(text) + return text, puncs + + def _strip_to_restore(self, text): + """Auxiliary method for Punctuation.preserve()""" + matches = list(re.finditer(self.puncs_regular_exp, text)) + if not matches: + return [text], [] + # the text is only punctuations + if len(matches) == 1 and matches[0].group() == text: + return [], [_PUNC_IDX(text, PuncPosition.ALONE)] + # build a punctuation map to be used later to restore punctuations + puncs = [] + for match in matches: + position = PuncPosition.MIDDLE + if match == matches[0] and text.startswith(match.group()): + position = PuncPosition.BEGIN + elif match == matches[-1] and text.endswith(match.group()): + position = PuncPosition.END + puncs.append(_PUNC_IDX(match.group(), position)) + # convert str text to a List[str], each item is separated by a punctuation + splitted_text = [] + for punc in puncs: + split = text.split(punc.punc) + prefix, suffix = split[0], punc.punc.join(split[1:]) + splitted_text.append(prefix) + text = suffix + return splitted_text, puncs + + @classmethod + def restore(cls, text, puncs): + """Restore punctuation in a text. + + Args: + text (str): The text to be processed. + puncs (List[str]): The list of punctuations map to be used for restoring. + + Examples :: + + ['This is', 'example'], ['.', '!'] -> "This is. example!" 
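+
+            A small round-trip sketch with ``strip_to_restore`` (``punc`` is
+            assumed to be a ``Punctuation()`` instance)::
+
+                split_text, puncs = punc.strip_to_restore("Hi, there!")
+                Punctuation.restore(split_text, puncs)  # -> ["Hi, there!"]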
+
+        """
+        return cls._restore(text, puncs, 0)
+
+    @classmethod
+    def _restore(cls, text, puncs, num):
+        """Auxiliary method for Punctuation.restore()"""
+        if not puncs:
+            return text
+
+        # nothing has been phonemized, return the puncs alone
+        if not text:
+            return ["".join(m.punc for m in puncs)]
+
+        current = puncs[0]
+
+        if current.position == PuncPosition.BEGIN:
+            return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
+
+        if current.position == PuncPosition.END:
+            return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
+
+        if current.position == PuncPosition.ALONE:
+            return [current.punc] + cls._restore(text, puncs[1:], num + 1)
+
+        # POSITION == MIDDLE
+        if len(text) == 1:  # pragma: nocover
+            # a corner case where the final part of an intermediate
+            # mark (I) has not been phonemized
+            return cls._restore([text[0] + current.punc], puncs[1:], num)
+
+        return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
+
+
+if __name__ == "__main__":
+    punc = Punctuation()
+    text = "This is. This is, example!"
+
+    print(punc.strip(text))
+
+    split_text, puncs = punc.strip_to_restore(text)
+    print(split_text, " ---- ", puncs)
+
+    restored_text = punc.restore(split_text, puncs)
+    print(restored_text)

From 3d86edfc81a150c7f13cd044685106411ed50700 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:29:57 +0100
Subject: [PATCH 031/214] Refactor Synthesizer class for TTSTokenizer

---
 TTS/utils/synthesizer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index fc45e7fa..12b71ab6 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -13,6 +13,7 @@ from TTS.tts.utils.speakers import SpeakerManager
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
 from TTS.tts.utils.synthesis import synthesis, trim_silence
+from TTS.tts.utils.text import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
@@ -114,6 +115,7 @@ class Synthesizer(object):
         self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
+        self.tokenizer = TTSTokenizer.init_from_config(self.tts_config)
 
         speaker_manager = self._init_speaker_manager()
         language_manager = self._init_language_manager()
@@ -332,6 +334,7 @@ class Synthesizer(object):
             CONFIG=self.tts_config,
             use_cuda=self.use_cuda,
             ap=self.ap,
+            tokenizer=self.tokenizer,
             speaker_id=speaker_id,
             language_id=language_id,
             language_name=language_name,

From 53f696615bf6a9fac0ebff27a56232752bd5a6a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:30:37 +0100
Subject: [PATCH 032/214] Add init_from_config to AudioProcessor

---
 TTS/utils/audio.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index 0253f918..ee255f44 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -379,6 +379,10 @@ class AudioProcessor(object):
             self.clip_norm = None
             self.symmetric_norm = None
 
+    @staticmethod
+    def init_from_config(config: "Coqpit"):
+        return AudioProcessor(**config.audio)
+
     ### setting up the parameters ###
     def _build_mel_basis(
         self,

From 2480bbe937a12b9400a24b9fad8d05e902ef044c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:31:18 +0100
Subject: 
[PATCH 033/214] Remove OLD TOKENIZATION ROUTINES --- TTS/tts/utils/text/__init__.py | 277 +-------------------------------- 1 file changed, 1 insertion(+), 276 deletions(-) diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 537d2301..593372dc 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -1,276 +1 @@ -# -*- coding: utf-8 -*- -# adapted from https://github.com/keithito/tacotron - -import re -from typing import Dict, List - -import gruut -from gruut_ipa import IPA - -from TTS.tts.utils.text import cleaners -from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes -from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes -from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols - -# pylint: disable=unnecessary-comprehension -# Mappings from symbol to numeric ID and vice versa: -_symbol_to_id = {s: i for i, s in enumerate(symbols)} -_id_to_symbol = {i: s for i, s in enumerate(symbols)} - -_phonemes_to_id = {s: i for i, s in enumerate(phonemes)} -_id_to_phonemes = {i: s for i, s in enumerate(phonemes)} - -_symbols = symbols -_phonemes = phonemes - -# Regular expression matching text enclosed in curly braces: -_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)") - -# Regular expression matching punctuations, ignoring empty space -PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+" - -# Table for str.translate to fix gruut/TTS phoneme mismatch -GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ") - - -def text2phone(text, language, use_espeak_phonemes=False, keep_stress=False): - """Convert graphemes to phonemes. - Parameters: - text (str): text to phonemize - language (str): language of the text - Returns: - ph (str): phonemes as a string seperated by "|" - ph = "ɪ|g|ˈ|z|æ|m|p|ə|l" - """ - - # TO REVIEW : How to have a good implementation for this? - if language == "zh-CN": - ph = chinese_text_to_phonemes(text) - return ph - - if language == "ja-jp": - ph = japanese_text_to_phonemes(text) - return ph - - if not gruut.is_language_supported(language): - raise ValueError(f" [!] Language {language} is not supported for phonemization.") - - # Use gruut for phonemization - ph_list = [] - for sentence in gruut.sentences(text, lang=language, espeak=use_espeak_phonemes): - for word in sentence: - if word.is_break: - # Use actual character for break phoneme (e.g., comma) - if ph_list: - # Join with previous word - ph_list[-1].append(word.text) - else: - # First word is punctuation - ph_list.append([word.text]) - elif word.phonemes: - # Add phonemes for word - word_phonemes = [] - - for word_phoneme in word.phonemes: - if not keep_stress: - # Remove primary/secondary stress - word_phoneme = IPA.without_stress(word_phoneme) - - word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE) - - if word_phoneme: - # Flatten phonemes - word_phonemes.extend(word_phoneme) - - if word_phonemes: - ph_list.append(word_phonemes) - - # Join and re-split to break apart dipthongs, suprasegmentals, etc. 
- ph_words = ["|".join(word_phonemes) for word_phonemes in ph_list] - ph = "| ".join(ph_words) - - return ph - - -def intersperse(sequence, token): - result = [token] * (len(sequence) * 2 + 1) - result[1::2] = sequence - return result - - -def pad_with_eos_bos(phoneme_sequence, tp=None): - # pylint: disable=global-statement - global _phonemes_to_id, _bos, _eos - if tp: - _bos = tp["bos"] - _eos = tp["eos"] - _, _phonemes = make_symbols(**tp) - _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} - - return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] - - -def phoneme_to_sequence( - text: str, - cleaner_names: List[str], - language: str, - enable_eos_bos: bool = False, - custom_symbols: List[str] = None, - tp: Dict = None, - add_blank: bool = False, - use_espeak_phonemes: bool = False, -) -> List[int]: - """Converts a string of phonemes to a sequence of IDs. - If `custom_symbols` is provided, it will override the default symbols. - - Args: - text (str): string to convert to a sequence - cleaner_names (List[str]): names of the cleaner functions to run the text through - language (str): text language key for phonemization. - enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens. - tp (Dict): dictionary of character parameters to use a custom character set. - add_blank (bool): option to add a blank token between each token. - use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequenc - - Returns: - List[int]: List of integers corresponding to the symbols in the text - """ - # pylint: disable=global-statement - global _phonemes_to_id, _phonemes - - if custom_symbols is not None: - _phonemes = custom_symbols - elif tp: - _, _phonemes = make_symbols(**tp) - _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} - - sequence = [] - clean_text = _clean_text(text, cleaner_names) - to_phonemes = text2phone(clean_text, language, use_espeak_phonemes=use_espeak_phonemes) - if to_phonemes is None: - print("!! After phoneme conversion the result is None. -- {} ".format(clean_text)) - # iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation. - for phoneme in filter(None, to_phonemes.split("|")): - sequence += _phoneme_to_sequence(phoneme) - # Append EOS char - if enable_eos_bos: - sequence = pad_with_eos_bos(sequence, tp=tp) - if add_blank: - sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) - return sequence - - -def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None): - # pylint: disable=global-statement - """Converts a sequence of IDs back to a string""" - global _id_to_phonemes, _phonemes - if add_blank: - sequence = list(filter(lambda x: x != len(_phonemes), sequence)) - result = "" - - if custom_symbols is not None: - _phonemes = custom_symbols - elif tp: - _, _phonemes = make_symbols(**tp) - _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} - - for symbol_id in sequence: - if symbol_id in _id_to_phonemes: - s = _id_to_phonemes[symbol_id] - result += s - return result.replace("}{", " ") - - -def text_to_sequence( - text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False -) -> List[int]: - """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - If `custom_symbols` is provided, it will override the default symbols. 
- - Args: - text (str): string to convert to a sequence - cleaner_names (List[str]): names of the cleaner functions to run the text through - tp (Dict): dictionary of character parameters to use a custom character set. - add_blank (bool): option to add a blank token between each token. - - Returns: - List[int]: List of integers corresponding to the symbols in the text - """ - # pylint: disable=global-statement - global _symbol_to_id, _symbols - - if custom_symbols is not None: - _symbols = custom_symbols - elif tp: - _symbols, _ = make_symbols(**tp) - _symbol_to_id = {s: i for i, s in enumerate(_symbols)} - - sequence = [] - - # Check for curly braces and treat their contents as ARPAbet: - while text: - m = _CURLY_RE.match(text) - if not m: - sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) - break - sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) - sequence += _arpabet_to_sequence(m.group(2)) - text = m.group(3) - - if add_blank: - sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols) - return sequence - - -def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None): - """Converts a sequence of IDs back to a string""" - # pylint: disable=global-statement - global _id_to_symbol, _symbols - if add_blank: - sequence = list(filter(lambda x: x != len(_symbols), sequence)) - - if custom_symbols is not None: - _symbols = custom_symbols - _id_to_symbol = {i: s for i, s in enumerate(_symbols)} - elif tp: - _symbols, _ = make_symbols(**tp) - _id_to_symbol = {i: s for i, s in enumerate(_symbols)} - - result = "" - for symbol_id in sequence: - if symbol_id in _id_to_symbol: - s = _id_to_symbol[symbol_id] - # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == "@": - s = "{%s}" % s[1:] - result += s - return result.replace("}{", " ") - - -def _clean_text(text, cleaner_names): - for name in cleaner_names: - cleaner = getattr(cleaners, name) - if not cleaner: - raise Exception("Unknown cleaner: %s" % name) - text = cleaner(text) - return text - - -def _symbols_to_sequence(syms): - return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)] - - -def _phoneme_to_sequence(phons): - return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)] - - -def _arpabet_to_sequence(text): - return _symbols_to_sequence(["@" + s for s in text.split()]) - - -def _should_keep_symbol(s): - return s in _symbol_to_id and s not in ["~", "^", "_"] - - -def _should_keep_phoneme(p): - return p in _phonemes_to_id and p not in ["~", "^", "_"] +from TTS.tts.utils.text.tokenizer import TTSTokenizer From e4049aa31a0a27e49267613e76806ff4df4f23c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:33:21 +0100 Subject: [PATCH 034/214] Refactor TTSDataset to use TTSTokenizer --- TTS/tts/datasets/dataset.py | 113 +++++++++--------------------------- 1 file changed, 28 insertions(+), 85 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 546f012d..8c21d7d0 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -10,7 +10,7 @@ import tqdm from torch.utils.data import Dataset from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor -from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence +from TTS.tts.utils.text import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -18,23 +18,17 @@ class TTSDataset(Dataset): 
def __init__( self, outputs_per_step: int, - text_cleaner: list, compute_linear_spec: bool, ap: AudioProcessor, meta_data: List[Dict], + tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, f0_cache_path: str = None, - characters: Dict = None, - custom_symbols: List = None, - add_blank: bool = False, return_wav: bool = False, batch_group_size: int = 0, min_seq_len: int = 0, max_seq_len: int = float("inf"), - use_phonemes: bool = False, phoneme_cache_path: str = None, - phoneme_language: str = "en-us", - enable_eos_bos: bool = False, speaker_id_mapping: Dict = None, d_vector_mapping: Dict = None, language_id_mapping: Dict = None, @@ -48,26 +42,19 @@ class TTSDataset(Dataset): Args: outputs_per_step (int): Number of time frames predicted per step. - text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs. - compute_linear_spec (bool): compute linear spectrogram if True. ap (TTS.tts.utils.AudioProcessor): Audio processor object. meta_data (list): List of dataset samples. + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else + use the given. Defaults to None. + compute_f0 (bool): compute f0 if True. Defaults to False. f0_cache_path (str): Path to store f0 cache. Defaults to None. - characters (dict): `dict` of custom text characters used for converting texts to sequences. - - custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own - set of symbols need to pass it here. Defaults to `None`. - - add_blank (bool): Add a special `blank` character after every other character. It helps some - models achieve better results. Defaults to false. - return_wav (bool): Return the waveform of the sample. Defaults to False. batch_group_size (int): Range of batch randomization after sorting @@ -82,16 +69,9 @@ class TTSDataset(Dataset): It helps for controlling the VRAM usage against long input sequences. Especially models with RNN layers are sensitive to input length. Defaults to `Inf`. - use_phonemes (bool): If true, input text converted to phonemes. Defaults to false. - phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a separate file. Defaults to None. - phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`. - - enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults - to False. - speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the embedding layer. Defaults to None. 
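+
+        Example:
+            A minimal wiring sketch after this refactor (``config`` and ``samples``
+            are assumed to come from a recipe, e.g. via ``load_tts_samples``; the
+            ``init_from_config`` helpers are the ones added earlier in this series)::
+
+                ap = AudioProcessor.init_from_config(config)
+                tokenizer = TTSTokenizer.init_from_config(config)
+                dataset = TTSDataset(
+                    outputs_per_step=1,
+                    compute_linear_spec=False,
+                    ap=ap,
+                    meta_data=samples,
+                    tokenizer=tokenizer,
+                )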
@@ -106,7 +86,6 @@ class TTSDataset(Dataset): self.items = meta_data self.outputs_per_step = outputs_per_step self.sample_rate = ap.sample_rate - self.cleaners = text_cleaner self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav self.compute_f0 = compute_f0 @@ -114,13 +93,7 @@ class TTSDataset(Dataset): self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap - self.characters = characters - self.custom_symbols = custom_symbols - self.add_blank = add_blank - self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path - self.phoneme_language = phoneme_language - self.enable_eos_bos = enable_eos_bos self.speaker_id_mapping = speaker_id_mapping self.d_vector_mapping = d_vector_mapping self.language_id_mapping = language_id_mapping @@ -130,17 +103,23 @@ class TTSDataset(Dataset): self.input_seq_computed = False self.rescue_item_idx = 1 self.pitch_computed = False + self.tokenizer = tokenizer - if use_phonemes and not os.path.isdir(phoneme_cache_path): + if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) if compute_f0: self.pitch_extractor = PitchExtractor(self.items, verbose=verbose) + if self.verbose: - print("\n > DataLoader initialization") - print(" | > Use phonemes: {}".format(self.use_phonemes)) - if use_phonemes: - print(" | > phoneme language: {}".format(phoneme_language)) - print(" | > Number of instances : {}".format(len(self.items))) + self.print_logs() + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> DataLoader initialization") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.items)}") def load_wav(self, filename): audio = self.ap.load_wav(filename) @@ -152,48 +131,30 @@ class TTSDataset(Dataset): return data @staticmethod - def _generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, custom_symbols, characters, add_blank - ): + def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path): """generate a phoneme sequence from text. since the usage is for subsequent caching, we never add bos and eos chars here. Instead we add those dynamically later; based on the config option.""" - phonemes = phoneme_to_sequence( - text, - [cleaners], - language=language, - enable_eos_bos=False, - custom_symbols=custom_symbols, - tp=characters, - add_blank=add_blank, - ) + phonemes = tokenizer.text_to_ids(text) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @staticmethod - def _load_or_generate_phoneme_sequence( - wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank - ): + def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path): file_name = os.path.splitext(os.path.basename(wav_file))[0] # different names for normal phonemes and with blank chars. - file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy" + file_name_ext = "_phoneme.npy" cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext) try: phonemes = np.load(cache_path) except FileNotFoundError: - phonemes = TTSDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, custom_symbols, characters, add_blank - ) + phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path) except (ValueError, IOError): print(" [!] 
failed loading phonemes for {}. " "Recomputing.".format(wav_file)) - phonemes = TTSDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, custom_symbols, characters, add_blank - ) - if enable_eos_bos: - phonemes = pad_with_eos_bos(phonemes, tp=characters) + phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path) phonemes = np.asarray(phonemes, dtype=np.int32) return phonemes @@ -208,27 +169,17 @@ class TTSDataset(Dataset): wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) if not self.input_seq_computed: - if self.use_phonemes: + if self.tokenizer.use_phonemes: text = self._load_or_generate_phoneme_sequence( item["audio_file"], item["text"], - self.phoneme_cache_path, - self.enable_eos_bos, - self.cleaners, item["language"] if item["language"] else self.phoneme_language, - self.custom_symbols, - self.characters, - self.add_blank, + self.tokenizer, + self.phoneme_cache_path, ) else: text = np.asarray( - text_to_sequence( - item["text"], - [self.cleaners], - custom_symbols=self.custom_symbols, - tp=self.characters, - add_blank=self.add_blank, - ), + self.tokenizer.text_to_ids(item["text"], item["language"]), dtype=np.int32, ) @@ -281,24 +232,16 @@ class TTSDataset(Dataset): print(" | > Computing input sequences ...") for idx, item in enumerate(tqdm.tqdm(self.items)): sequence = np.asarray( - text_to_sequence( - item["text"], - [self.cleaners], - custom_symbols=self.custom_symbols, - tp=self.characters, - add_blank=self.add_blank, - ), + self.tokenizer.text_to_ids(item["text"]), dtype=np.int32, ) self.items[idx][0] = sequence - else: func_args = [ self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, self.phoneme_language, - self.custom_symbols, self.characters, self.add_blank, ] From e5785b34b07c6c923ef5c7a3b67c69df8d256b24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:34:17 +0100 Subject: [PATCH 035/214] Style fix --- TTS/tts/utils/visual.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index ff71958e..de6d95c5 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -4,8 +4,6 @@ import matplotlib.pyplot as plt import numpy as np import torch -from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme - matplotlib.use("Agg") @@ -95,6 +93,7 @@ def visualize( text, hop_length, CONFIG, + tokenizer, stop_tokens=None, decoder_output=None, output_path=None, @@ -117,14 +116,8 @@ def visualize( plt.ylabel("Encoder timestamp", fontsize=label_fontsize) # compute phoneme representation and back if CONFIG.use_phonemes: - seq = phoneme_to_sequence( - text, - [CONFIG.text_cleaner], - CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if "characters" in CONFIG.keys() else None, - ) - text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None) + seq = tokenizer.text_to_ids(text) + text = tokenizer.ids_to_text(seq) print(text) plt.yticks(range(len(text)), list(text)) plt.colorbar() From 5a9653978a412d26b02aed74df280190af66a15c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:34:45 +0100 Subject: [PATCH 036/214] Refactor synthesis.py for TTSTokenizer --- TTS/tts/utils/synthesis.py | 53 +++++++++----------------------------- 1 file changed, 12 insertions(+), 41 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index b2ea4208..9a34e5d4 100644 --- 
a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -4,34 +4,6 @@ import numpy as np import torch from torch import nn -from .text import phoneme_to_sequence, text_to_sequence - - -def text_to_seq(text, CONFIG, custom_symbols=None, language=None): - text_cleaner = [CONFIG.text_cleaner] - # text ot phonemes to sequence vector - if CONFIG.use_phonemes: - seq = np.asarray( - phoneme_to_sequence( - text, - text_cleaner, - language if language else CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters, - add_blank=CONFIG.add_blank, - use_espeak_phonemes=CONFIG.use_espeak_phonemes, - custom_symbols=custom_symbols, - ), - dtype=np.int32, - ) - else: - seq = np.asarray( - text_to_sequence( - text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols - ), - dtype=np.int32, - ) - return seq def numpy_to_torch(np_array, dtype, cuda=False): @@ -143,9 +115,9 @@ def synthesis( CONFIG, use_cuda, ap, + tokenizer, speaker_id=None, style_wav=None, - enable_eos_bos_chars=False, # pylint: disable=unused-argument use_griffin_lim=False, do_trim_silence=False, d_vector=None, @@ -194,17 +166,17 @@ def synthesis( """ # GST processing style_mel = None - custom_symbols = None - if style_wav: - style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) - elif CONFIG.has("gst") and CONFIG.gst and not style_wav: - if CONFIG.gst.gst_style_input_weights: - style_mel = CONFIG.gst.gst_style_input_weights - if hasattr(model, "make_symbols"): - custom_symbols = model.make_symbols(CONFIG) - # preprocess the given text - text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name) - + if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: + if isinstance(style_wav, dict): + style_mel = style_wav + else: + style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) + # convert text to sequence of token IDs + text_inputs = np.asarray( + tokenizer.text_to_ids(text), + dtype=np.int32, + ) + # pass tensors to backend if speaker_id is not None: speaker_id = id_to_torch(speaker_id, cuda=use_cuda) @@ -218,7 +190,6 @@ def synthesis( style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) text_inputs = text_inputs.unsqueeze(0) - # synthesize voice outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id) model_outputs = outputs["model_outputs"] From bd461ace337b7df9b750d1bc5a81e33c253c6429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:36:35 +0100 Subject: [PATCH 037/214] Refactor GlowTTS model and recipe for TTSTokenizer --- TTS/model.py | 9 ++-- TTS/tts/models/base_tts.py | 62 ++++++++++++---------- TTS/tts/models/glow_tts.py | 12 ++--- recipes/ljspeech/glow_tts/train_glowtts.py | 12 +++-- 4 files changed, 54 insertions(+), 41 deletions(-) diff --git a/TTS/model.py b/TTS/model.py index 532d05a6..a7c64dde 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -22,10 +22,13 @@ class BaseModel(nn.Module, ABC): def __init__(self, config: Coqpit): super().__init__() - self._set_model_args(config) - def _set_model_args(self, config: Coqpit): - """Set model arguments from the config. Override this.""" + @staticmethod + def init_from_config(config: Coqpit): + """Init the model from given config. + + Override this depending on your model. 
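+
+        A subclass would typically build itself from the config fields, e.g. a
+        minimal sketch where ``MyModel`` is a hypothetical subclass::
+
+            @staticmethod
+            def init_from_config(config: Coqpit):
+                return MyModel(config)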
+ """ pass @abstractmethod diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index e52cd765..45cae79e 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -15,7 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text import make_symbols +from TTS.tts.utils.text.symbols import Graphemes, make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file @@ -34,8 +34,20 @@ class BaseTTS(BaseModel): - 1D tensors `batch x 1` """ + def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): + super().__init__(config) + self.config = config + self.ap = ap + self.tokenizer = tokenizer + self.speaker_manager = speaker_manager + self._set_model_args(config) + def _set_model_args(self, config: Coqpit): - """Setup model args based on the config type. + """Setup model args based on the config type (`ModelConfig` or `ModelArgs`). + + `ModelArgs` has all the fields reuqired to initialize the model architecture. + + `ModelConfig` has all the fields required for training, inference and containes `ModelArgs`. If the config is for training with a name like "*Config", then the model args are embeded in the config.model_args @@ -44,8 +56,8 @@ class BaseTTS(BaseModel): """ # don't use isintance not to import recursively if "Config" in config.__class__.__name__: + num_chars = self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars if "characters" in config: - _, self.config, num_chars = self.get_characters(config) self.config.num_chars = num_chars if hasattr(self.config, "model_args"): config.model_args.num_chars = num_chars @@ -58,18 +70,21 @@ class BaseTTS(BaseModel): else: raise ValueError("config must be either a *Config or *Args") - @staticmethod - def get_characters(config: Coqpit) -> str: - # TODO: implement CharacterProcessor - if config.characters is not None: - symbols, phonemes = make_symbols(**config.characters) - else: - from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols + # @staticmethod + # def get_characters(config: Coqpit) -> str: + # # TODO: implement CharacterProcessor + # if config.characters is not None: + # symbols, phonemes = make_symbols(**config.characters) + # else: + # from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols - config.characters = CharactersConfig(**parse_symbols()) - model_characters = phonemes if config.use_phonemes else symbols - num_chars = len(model_characters) + getattr(config, "add_blank", False) - return model_characters, config, num_chars + # if config.use_phonemes: + + # config.characters = Graphemes() + + # model_characters = phonemes if config.use_phonemes else symbols + # num_chars = len(model_characters) + getattr(config, "add_blank", False) + # return model_characters, config, num_chars def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: return get_speaker_manager(config, restore_path, data, out_path) @@ -247,8 +262,6 @@ class BaseTTS(BaseModel): if is_eval and not config.run_eval: loader = None else: - ap = assets["audio_processor"] - # setup multi-speaker attributes if hasattr(self, "speaker_manager") and self.speaker_manager 
is not None: if hasattr(config, "model_args"): @@ -279,28 +292,21 @@ class BaseTTS(BaseModel): # init dataloader dataset = TTSDataset( outputs_per_step=config.r if "r" in config else 1, - text_cleaner=config.text_cleaner, compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, compute_f0=config.get("compute_f0", False), f0_cache_path=config.get("f0_cache_path", None), meta_data=data_items, - ap=ap, - characters=config.characters, - custom_symbols=custom_symbols, - add_blank=config["add_blank"], + ap=self.ap, return_wav=config.return_wav if "return_wav" in config else False, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, min_seq_len=config.min_seq_len, max_seq_len=config.max_seq_len, phoneme_cache_path=config.phoneme_cache_path, - use_phonemes=config.use_phonemes, - phoneme_language=config.phoneme_language, - enable_eos_bos=config.enable_eos_bos_chars, use_noise_augment=False if is_eval else config.use_noise_augment, verbose=verbose, speaker_id_mapping=speaker_id_mapping, - d_vector_mapping=d_vector_mapping, - language_id_mapping=language_id_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=self.tokenizer ) # pre-compute phonemes @@ -332,7 +338,7 @@ class BaseTTS(BaseModel): if config.compute_f0 and rank in [None, 0]: if not os.path.exists(config.f0_cache_path): dataset.pitch_extractor.compute_pitch( - ap, config.get("f0_cache_path", None), config.num_loader_workers + self.ap, config.get("f0_cache_path", None), config.num_loader_workers ) # halt DDP processes for the main process to finish computing the F0 cache @@ -404,6 +410,7 @@ class BaseTTS(BaseModel): Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ ap = assets["audio_processor"] + tokenizer = assets["tokenizer"] print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} @@ -416,6 +423,7 @@ class BaseTTS(BaseModel): self.config, "cuda" in str(next(self.parameters()).device), ap, + tokenizer, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 8f3b3804..907f3846 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -46,11 +46,9 @@ class GlowTTS(BaseTTS): """ - def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None): + def __init__(self, config: GlowTTSConfig, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): - super().__init__(config) - - self.speaker_manager = speaker_manager + super().__init__(config, ap, tokenizer, speaker_manager) # pass all config fields to `self` # for fewer code change @@ -58,7 +56,7 @@ class GlowTTS(BaseTTS): for key in config: setattr(self, key, config[key]) - _, self.config, self.num_chars = self.get_characters(config) + self.num_chars = self.tokenizer.characters.num_chars self.decoder_output_dim = config.out_channels # init multi-speaker layers if necessary @@ -448,7 +446,6 @@ class GlowTTS(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
""" - ap = assets["audio_processor"] print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} @@ -463,7 +460,8 @@ class GlowTTS(BaseTTS): sen, self.config, "cuda" in str(next(self.parameters()).device), - ap, + self.ap, + self.tokenizer, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index 7bd9ea19..fe4a9d9b 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -11,6 +11,7 @@ from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.glow_tts import GlowTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # we use the same path as this script as our training folder. @@ -47,7 +48,11 @@ config = GlowTTSConfig( # INITIALIZE THE AUDIO PROCESSOR # Audio processor is used for feature extraction and audio I/O. # It mainly serves to the dataloader and the training loggers. -ap = AudioProcessor(**config.audio.to_dict()) +ap = AudioProcessor.init_from_config(config) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +tokenizer = TTSTokenizer.init_from_config(config) # LOAD DATA SAMPLES # Each sample is a list of ```[text, audio_file_path, speaker_name]``` @@ -60,7 +65,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # Models take a config object and a speaker manager as input # Config defines the details of the model like the number of layers, the size of the embedding, etc. # Speaker manager is used by multi-speaker models. -model = GlowTTS(config, speaker_manager=None) +model = GlowTTS(config, ap, tokenizer, speaker_manager=None) # INITIALIZE THE TRAINER # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, @@ -71,8 +76,7 @@ trainer = Trainer( output_path, model=model, train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, # assets are objetcs used by the models but not class members. + eval_samples=eval_samples ) # AND... 3,2,1... 🚀 From a1df4f98875409404289e0ed71b17becd8b377ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 17 Nov 2021 12:43:45 +0100 Subject: [PATCH 038/214] Test character classes --- TTS/tts/utils/text/characters.py | 344 ++++++++++++++++++++++++++++ tests/text_tests/test_characters.py | 127 ++++++++++ tests/text_tests/test_symbols.py | 8 - 3 files changed, 471 insertions(+), 8 deletions(-) create mode 100644 TTS/tts/utils/text/characters.py create mode 100644 tests/text_tests/test_characters.py delete mode 100644 tests/text_tests/test_symbols.py diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py new file mode 100644 index 00000000..05882e57 --- /dev/null +++ b/TTS/tts/utils/text/characters.py @@ -0,0 +1,344 @@ +def parse_symbols(): + return { + "pad": _pad, + "eos": _eos, + "bos": _bos, + "characters": _characters, + "punctuations": _punctuations, + "phonemes": _phonemes, + } + + +# DEFAULT SET OF GRAPHEMES +_pad = "" +_eos = "" +_bos = "" +_blank = "" # TODO: check if we need this alongside with PAD +_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_punctuations = "!'(),-.:;? 
" + + +# DEFAULT SET OF IPA PHONEMES +# Phonemes definition (All IPA characters) +_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" +_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" +_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" +_suprasegmentals = "ˈˌːˑ" +_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" +_diacrilics = "ɚ˞ɫ" +_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics + + +def create_graphemes( + characters=_characters, + punctuations=_punctuations, + pad=_pad, + eos=_eos, + bos=_bos, + blank=_blank, + unique=True, +): # pylint: disable=redefined-outer-name + """Function to create default characters and phonemes""" + # create graphemes + _graphemes = list(characters) + _graphemes = [bos] + _graphemes if len(bos) > 0 and bos is not None else _graphemes + _graphemes = [eos] + _graphemes if len(bos) > 0 and eos is not None else _graphemes + _graphemes = [pad] + _graphemes if len(bos) > 0 and pad is not None else _graphemes + _graphemes = [blank] + _graphemes if len(bos) > 0 and blank is not None else _graphemes + _graphemes = _graphemes + list(punctuations) + return _graphemes, _phonemes + + +def create_phonemes( + phonemes=_phonemes, punctuations=_punctuations, pad=_pad, eos=_eos, bos=_bos, blank=_blank, unique=True +): + # create phonemes + _phonemes = None + _phonemes_sorted = ( + sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) + ) # this is to keep previous models compatible. + _phonemes = list(_phonemes_sorted) + _phonemes = [bos] + _phonemes if len(bos) > 0 and bos is not None else _phonemes + _phonemes = [eos] + _phonemes if len(bos) > 0 and eos is not None else _phonemes + _phonemes = [pad] + _phonemes if len(bos) > 0 and pad is not None else _phonemes + _phonemes = [blank] + _phonemes if len(bos) > 0 and blank is not None else _phonemes + _phonemes = _phonemes + list(punctuations) + _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) + return _phonemes + + +graphemes = create_graphemes(_characters, _phonemes, _punctuations, _pad, _eos, _bos) +phonemes = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank) + + +class BaseCharacters: + """🐸BaseCharacters class + + Every new character class should inherit from this. + + Characters are oredered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```. + + If you need a custom order, you need to define inherit from this class and override the ```_create_vocab``` method. + + Args: + characters (str): + Main set of characters to be used in the vocabulary. + + punctuations (str): + Characters to be treated as punctuation. + + pad (str): + Special padding character that would be ignored by the model. + + eos (str): + End of the sentence character. + + bos (str): + Beginning of the sentence character. + + blank (str): + Optional character used between characters by some models for better prosody. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. +el + is_sorted (bool): + Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True. 
+ """ + + def __init__( + self, + characters: str, + punctuations: str, + pad: str, + eos: str, + bos: str, + blank: str, + is_unique: bool = True, + is_sorted: bool = True, + ) -> None: + self._characters = characters + self._punctuations = punctuations + self._pad = pad + self._eos = eos + self._bos = bos + self._blank = blank + self.is_unique = is_unique + self.is_sorted = is_sorted + self._create_vocab() + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, characters): + self._characters = characters + self._create_vocab() + + @property + def punctuations(self): + return self._punctuations + + @punctuations.setter + def punctuations(self, punctuations): + self._punctuations = punctuations + self._create_vocab() + + @property + def pad(self): + return self._pad + + @pad.setter + def pad(self, pad): + self._pad = pad + self._create_vocab() + + @property + def eos(self): + return self._eos + + @eos.setter + def eos(self, eos): + self._eos = eos + self._create_vocab() + + @property + def bos(self): + return self._bos + + @bos.setter + def bos(self, bos): + self._bos = bos + self._create_vocab() + + @property + def blank(self): + return self._blank + + @blank.setter + def blank(self, blank): + self._blank = blank + self._create_vocab() + + @property + def vocab(self): + return self._vocab + + @property + def num_chars(self): + return len(self._vocab) + + def _create_vocab(self): + _vocab = self._characters + if self.is_unique: + _vocab = list(set(_vocab)) + if self.is_sorted: + _vocab = sorted(_vocab) + _vocab = list(_vocab) + _vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab + _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab + _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab + _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab + self._vocab = _vocab + list(self._punctuations) + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + if self.is_unique: + assert ( + len(self.vocab) == len(self._char_to_id) == len(self._id_to_char) + ), f" [!] There are duplicate characters in the character set." + + def char_to_id(self, char: str) -> int: + return self._char_to_id[char] + + def id_to_char(self, idx: int) -> str: + return self._id_to_char[idx] + + def print_log(self, level:int=0): + """ + Prints the vocabulary in a nice format. + """ + indent = "\t" * level + print(f"{indent}| > Characters: {self._characters}") + print(f"{indent}| > Punctuations: {self._punctuations}") + print(f"{indent}| > Pad: {self._pad}") + print(f"{indent}| > EOS: {self._eos}") + print(f"{indent}| > BOS: {self._bos}") + print(f"{indent}| > Blank: {self._blank}") + print(f"{indent}| > Vocab: {self.vocab}") + print(f"{indent}| > Num chars: {self.num_chars}") + + @staticmethod + def init_from_config(config: "Coqpit"): + return BaseCharacters( + **config.characters if config.characters is not None else {}, + ) + + +class IPAPhonemes(BaseCharacters): + """🐸IPAPhonemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using IPAPhonemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`. 
+ + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _phonemes, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = True, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + return IPAPhonemes( + **config.characters if config.characters is not None else {}, + ) + + +class Graphemes(BaseCharacters): + """🐸Graphemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using graphemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = True, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + return Graphemes( + **config.characters if config.characters is not None else {}, + ) + + +if __name__ == "__main__": + gr = Graphemes() + ph = IPAPhonemes() + + print(gr.vocab) + print(ph.vocab) + + print(gr.num_chars) + assert "a" == gr.id_to_char(gr.char_to_id("a")) diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py new file mode 100644 index 00000000..5a051ac4 --- /dev/null +++ b/tests/text_tests/test_characters.py @@ -0,0 +1,127 @@ +import unittest + +from TTS.tts.utils.text.characters import ( + BaseCharacters, + IPAPhonemes, + Graphemes, + create_graphemes, + create_phonemes, +) + + +def test_make_symbols(): + _ = create_phonemes() + _ = create_graphemes() + + +class BaseCharacterTest(unittest.TestCase): + def setUp(self): + self.characters_empty = BaseCharacters( + "", + "", + pad="", + eos="", + bos="", + blank="", + is_unique=True, + is_sorted=True + ) + + def test_default_character_sets(self): + """Test initiation of default character sets""" + _ = IPAPhonemes() + _ = Graphemes() + + def test_unique(self): + """Test if the unique option works""" + self.characters_empty.characters = "abcc" + self.characters_empty.punctuations = ".,;:!? 
" + self.characters_empty.pad = "[PAD]" + self.characters_empty.eos = "[EOS]" + self.characters_empty.bos = "[BOS]" + self.characters_empty.blank = "[BLANK]" + + self.assertEqual(self.characters_empty.num_chars, len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "])) + + + def test_unique_sorted(self): + """Test if the unique and sorted option works""" + self.characters_empty.characters = "cba" + self.characters_empty.punctuations = ".,;:!? " + self.characters_empty.pad = "[PAD]" + self.characters_empty.eos = "[EOS]" + self.characters_empty.bos = "[BOS]" + self.characters_empty.blank = "[BLANK]" + + self.assertEqual(self.characters_empty.num_chars, len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "])) + + def test_setters_getters(self): + """Test the class setters behaves as expected""" + self.characters_empty.characters = "abc" + self.assertEqual(self.characters_empty._characters, "abc") + self.assertEqual(self.characters_empty.vocab, ["a", "b", "c"]) + + self.characters_empty.punctuations = ".,;:!? " + self.assertEqual(self.characters_empty._punctuations, ".,;:!? ") + self.assertEqual(self.characters_empty.vocab, ["a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + + self.characters_empty.pad = "[PAD]" + self.assertEqual(self.characters_empty._pad, "[PAD]") + self.assertEqual(self.characters_empty.vocab, ["[PAD]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + + self.characters_empty.eos = "[EOS]" + self.assertEqual(self.characters_empty._eos, "[EOS]") + self.assertEqual(self.characters_empty.vocab, ["[PAD]", "[EOS]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + + self.characters_empty.bos = "[BOS]" + self.assertEqual(self.characters_empty._bos, "[BOS]") + self.assertEqual(self.characters_empty.vocab, ["[PAD]", "[EOS]", "[BOS]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + + self.characters_empty.blank = "[BLANK]" + self.assertEqual(self.characters_empty._blank, "[BLANK]") + self.assertEqual(self.characters_empty.vocab, ["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + self.assertEqual(self.characters_empty.num_chars, len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "])) + + self.characters_empty.print_log() + + def test_char_lookup(self): + """Test char to ID and ID to char conversion""" + self.characters_empty.characters = "abc" + self.characters_empty.punctuations = ".,;:!? 
" + self.characters_empty.pad = "[PAD]" + self.characters_empty.eos = "[EOS]" + self.characters_empty.bos = "[BOS]" + self.characters_empty.blank = "[BLANK]" + + # char to ID + self.assertEqual(self.characters_empty.char_to_id("[PAD]"), 0) + self.assertEqual(self.characters_empty.char_to_id("[EOS]"), 1) + self.assertEqual(self.characters_empty.char_to_id("[BOS]"), 2) + self.assertEqual(self.characters_empty.char_to_id("[BLANK]"), 3) + self.assertEqual(self.characters_empty.char_to_id("a"), 4) + self.assertEqual(self.characters_empty.char_to_id("b"), 5) + self.assertEqual(self.characters_empty.char_to_id("c"), 6) + self.assertEqual(self.characters_empty.char_to_id("."), 7) + self.assertEqual(self.characters_empty.char_to_id(","), 8) + self.assertEqual(self.characters_empty.char_to_id(";"), 9) + self.assertEqual(self.characters_empty.char_to_id(":"), 10) + self.assertEqual(self.characters_empty.char_to_id("!"), 11) + self.assertEqual(self.characters_empty.char_to_id("?"), 12) + self.assertEqual(self.characters_empty.char_to_id(" "), 13) + + # ID to char + self.assertEqual(self.characters_empty.id_to_char(0), "[PAD]") + self.assertEqual(self.characters_empty.id_to_char(1), "[EOS]") + self.assertEqual(self.characters_empty.id_to_char(2), "[BOS]") + self.assertEqual(self.characters_empty.id_to_char(3), "[BLANK]") + self.assertEqual(self.characters_empty.id_to_char(4), "a") + self.assertEqual(self.characters_empty.id_to_char(5), "b") + self.assertEqual(self.characters_empty.id_to_char(6), "c") + self.assertEqual(self.characters_empty.id_to_char(7), ".") + self.assertEqual(self.characters_empty.id_to_char(8), ",") + self.assertEqual(self.characters_empty.id_to_char(9), ";") + self.assertEqual(self.characters_empty.id_to_char(10), ":") + self.assertEqual(self.characters_empty.id_to_char(11), "!") + self.assertEqual(self.characters_empty.id_to_char(12), "?") + self.assertEqual(self.characters_empty.id_to_char(13), " ") + diff --git a/tests/text_tests/test_symbols.py b/tests/text_tests/test_symbols.py deleted file mode 100644 index 49b25986..00000000 --- a/tests/text_tests/test_symbols.py +++ /dev/null @@ -1,8 +0,0 @@ -import unittest - -from TTS.tts.utils.text import phonemes - - -class SymbolsTest(unittest.TestCase): - def test_uniqueness(self): # pylint: disable=no-self-use - assert sorted(phonemes) == sorted(list(set(phonemes))), " {} vs {} ".format(len(phonemes), len(set(phonemes))) From fbad17e084909ee52a246bc7a238beb806d82252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 17 Nov 2021 12:46:04 +0100 Subject: [PATCH 039/214] Update imports for symbols -> characters --- TTS/bin/compute_attention_masks.py | 2 +- TTS/speaker_encoder/utils/training.py | 2 +- TTS/tts/models/__init__.py | 4 +- TTS/tts/models/base_tts.py | 14 +- TTS/tts/models/glow_tts.py | 8 +- TTS/tts/utils/synthesis.py | 1 - TTS/tts/utils/text/phonemizers/base.py | 3 +- TTS/tts/utils/text/symbols.py | 316 --------------------- TTS/tts/utils/text/tokenizer.py | 6 +- recipes/ljspeech/glow_tts/train_glowtts.py | 7 +- 10 files changed, 26 insertions(+), 337 deletions(-) delete mode 100644 TTS/tts/utils/text/symbols.py diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index fc8c6629..e58259a6 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -11,7 +11,7 @@ from tqdm import tqdm from TTS.config import load_config from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.models import setup_model -from TTS.tts.utils.text.symbols 
import make_symbols, phonemes, symbols +from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_checkpoint diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index a32f43bd..b202ebcd 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -4,7 +4,7 @@ from coqpit import Coqpit from TTS.config import load_config, register_config from TTS.trainer import TrainingArgs -from TTS.tts.utils.text.symbols import parse_symbols +from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files from TTS.utils.logging import init_dashboard_logger diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 4cc8b658..c8371106 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,4 +1,4 @@ -from TTS.tts.utils.text.symbols import make_symbols, parse_symbols +from TTS.tts.utils.text.characters import make_symbols, parse_symbols from TTS.utils.generic_utils import find_module @@ -17,7 +17,7 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manag else: symbols, phonemes = make_symbols(**config.characters) else: - from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel + from TTS.tts.utils.text.characters import phonemes, symbols # pylint: disable=import-outside-toplevel if config.use_phonemes: symbols = phonemes diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 45cae79e..98f68742 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -15,7 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.symbols import Graphemes, make_symbols +from TTS.tts.utils.text.characters import Graphemes, make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file @@ -34,7 +34,9 @@ class BaseTTS(BaseModel): - 1D tensors `batch x 1` """ - def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): + def __init__( + self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None + ): super().__init__(config) self.config = config self.ap = ap @@ -56,7 +58,9 @@ class BaseTTS(BaseModel): """ # don't use isintance not to import recursively if "Config" in config.__class__.__name__: - num_chars = self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + num_chars = ( + self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + ) if "characters" in config: self.config.num_chars = num_chars if hasattr(self.config, "model_args"): @@ -76,7 +80,7 @@ class BaseTTS(BaseModel): # if config.characters is not None: # symbols, phonemes = make_symbols(**config.characters) # else: - # from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols + # from TTS.tts.utils.text.characters import parse_symbols, phonemes, symbols # if config.use_phonemes: @@ -306,7 +310,7 @@ class BaseTTS(BaseModel): verbose=verbose, 
speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, - tokenizer=self.tokenizer + tokenizer=self.tokenizer, ) # pre-compute phonemes diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 907f3846..9e779f8e 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -46,7 +46,13 @@ class GlowTTS(BaseTTS): """ - def __init__(self, config: GlowTTSConfig, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): + def __init__( + self, + config: GlowTTSConfig, + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + ): super().__init__(config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 9a34e5d4..a4f4b0c8 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -5,7 +5,6 @@ import torch from torch import nn - def numpy_to_torch(np_array, dtype, cuda=False): if np_array is None: return None diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py index b370822c..a14a73bc 100644 --- a/TTS/tts/utils/text/phonemizers/base.py +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -32,7 +32,6 @@ class BasePhonemizer(abc.ABC): self._keep_puncs = keep_puncs self._punctuator = Punctuation(punctuations) - def _init_language(self, language): """Language initialization @@ -130,7 +129,7 @@ class BasePhonemizer(abc.ABC): phonemized = self._phonemize_postprocess(phonemized, punctuations) return phonemized - def print_logs(self, level: int=0): + def print_logs(self, level: int = 0): indent = "\t" * level print(f"{indent}| > phoneme language: {self.language}") print(f"{indent}| > phoneme backend: {self.name()}") diff --git a/TTS/tts/utils/text/symbols.py b/TTS/tts/utils/text/symbols.py deleted file mode 100644 index ce59031d..00000000 --- a/TTS/tts/utils/text/symbols.py +++ /dev/null @@ -1,316 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Defines the set of symbols used in text input to the model. - -The default is a set of ASCII characters that works well for English or text that has been run -through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. -""" - - -def parse_symbols(): - return { - "pad": _pad, - "eos": _eos, - "bos": _bos, - "characters": _characters, - "punctuations": _punctuations, - "phonemes": _phonemes, - } - - -def make_symbols( - characters, - phonemes=None, - punctuations="!'(),-.:;? ", - pad="", - eos="", - bos="", - blank="", - unique=True, -): # pylint: disable=redefined-outer-name - """Function to create default characters and phonemes""" - _symbols = list(characters) - _symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols - _symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols - _symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols - _symbols = [blank] + _symbols if len(bos) > 0 and blank is not None else _symbols - _phonemes = None - if phonemes is not None: - _phonemes_sorted = ( - sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) - ) # this is to keep previous models compatible. 
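The ordering this removed helper hard-codes ([pad, eos, bos] + sorted phonemes + punctuations) is what the new characters.py reimplements through BaseCharacters._create_vocab, which also accounts for the optional blank symbol. A minimal sketch of the new ordering, assuming the new module is importable as in the test file added earlier in this series; the toy symbols here are hypothetical:

from TTS.tts.utils.text.characters import BaseCharacters

# [PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS], with characters deduped and sorted
chars = BaseCharacters("cba", "!? ", pad="<PAD>", eos="<EOS>", bos="<BOS>", blank="<BLNK>", is_unique=True, is_sorted=True)
print(chars.vocab)  # ['<PAD>', '<EOS>', '<BOS>', '<BLNK>', 'a', 'b', 'c', '!', '?', ' ']
print(chars.char_to_id("a"), chars.id_to_char(4))  # 4 'a', matching the lookup tests above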
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): - # _arpabet = ["@" + s for s in _phonemes_sorted] - # Export all symbols: - _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) - # _symbols += _arpabet - return _symbols, _phonemes - - -_pad = "" -_eos = "" -_bos = "" -_blank = "" # TODO: check if we need this alongside with PAD -_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? " -_punctuations = "!'(),-.:;? " - -# Phonemes definition (All IPA characters) -_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" -_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" -_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" -_suprasegmentals = "ˈˌːˑ" -_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" -_diacrilics = "ɚ˞ɫ" -_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics - -symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos) - - -class BaseCharacters: - """🐸BaseCharacters class - - Every vocabulary class should inherit from this class. - - Args: - characters (str): - Main set of characters to be used in the vocabulary. - - punctuations (str): - Characters to be treated as punctuation. - - pad (str): - Special padding character that would be ignored by the model. - - eos (str): - End of the sentence character. - - bos (str): - Beginning of the sentence character. - - blank (str): - Optional character used between characters by some models for better prosody. - - is_unique (bool): - Remove duplicates from the provided characters. Defaults to True. - - is_sorted (bool): - Sort the characters in alphabetical order. Defaults to True. - """ - - def __init__( - self, - characters: str, - punctuations: str, - pad: str, - eos: str, - bos: str, - blank: str, - is_unique: bool = True, - is_sorted: bool = True, - ) -> None: - self._characters = characters - self._punctuations = punctuations - self._pad = pad - self._eos = eos - self._bos = bos - self._blank = blank - self.is_unique = is_unique - self.is_sorted = is_sorted - self._create_vocab() - - @property - def characters(self): - return self._characters - - @characters.setter - def characters(self, characters): - self._characters = characters - self._vocab = self.create_vocab() - - @property - def punctuations(self): - return self._punctuations - - @punctuations.setter - def punctuations(self, punctuations): - self._punctuations = punctuations - self._vocab = self.create_vocab() - - @property - def pad(self): - return self._pad - - @pad.setter - def pad(self, pad): - self._pad = pad - self._vocab = self.create_vocab() - - @property - def eos(self): - return self._eos - - @eos.setter - def eos(self, eos): - self._eos = eos - self._vocab = self.create_vocab() - - @property - def bos(self): - return self._bos - - @bos.setter - def bos(self, bos): - self._bos = bos - self._vocab = self.create_vocab() - - @property - def blank(self): - return self._bos - - @bos.setter - def blank(self, bos): - self._bos = bos - self._vocab = self.create_vocab() - - @property - def vocab(self): - return self._vocab - - @property - def num_chars(self): - return len(self._vocab) - - def _create_vocab(self): - _vocab = self.characters - if self.is_unique: - _vocab = list(set(_vocab)) - if self.is_sorted: - _vocab = sorted(_vocab) - _vocab = list(_vocab) - _vocab = [self.bos] + _vocab if len(self.bos) > 0 and self.bos is not None else _vocab - _vocab = [self.eos] + _vocab if len(self.bos) > 
0 and self.eos is not None else _vocab - _vocab = [self.pad] + _vocab if len(self.bos) > 0 and self.pad is not None else _vocab - self._vocab = _vocab + list(self._punctuations) - self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} - self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} - assert len(self.vocab) == len(self._char_to_id) == len(self._id_to_char) - - def char_to_id(self, char: str) -> int: - return self._char_to_id[char] - - def id_to_char(self, idx: int) -> str: - return self._id_to_char[idx] - - @staticmethod - def init_from_config(config: "Coqpit"): - return BaseCharacters( - **config.characters if config.characters is not None else {}, - ) - - -class IPAPhonemes(BaseCharacters): - """🐸IPAPhonemes class to manage `TTS.tts` model vocabulary - - Intended to be used with models using IPAPhonemes as input. - It uses system defaults for the undefined class arguments. - - Args: - characters (str): - Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`. - - punctuations (str): - Characters to be treated as punctuation. Defaults to `_punctuations`. - - pad (str): - Special padding character that would be ignored by the model. Defaults to `_pad`. - - eos (str): - End of the sentence character. Defaults to `_eos`. - - bos (str): - Beginning of the sentence character. Defaults to `_bos`. - - is_unique (bool): - Remove duplicates from the provided characters. Defaults to True. - - is_sorted (bool): - Sort the characters in alphabetical order. Defaults to True. - """ - - def __init__( - self, - characters: str = _phonemes, - punctuations: str = _punctuations, - pad: str = _pad, - eos: str = _eos, - bos: str = _bos, - is_unique: bool = True, - is_sorted: bool = True, - ) -> None: - super().__init__(characters, punctuations, pad, eos, bos, is_unique, is_sorted) - - @staticmethod - def init_from_config(config: "Coqpit"): - return IPAPhonemes( - **config.characters if config.characters is not None else {}, - ) - - -class Graphemes(BaseCharacters): - """🐸Graphemes class to manage `TTS.tts` model vocabulary - - Intended to be used with models using graphemes as input. - It uses system defaults for the undefined class arguments. - - Args: - characters (str): - Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`. - - punctuations (str): - Characters to be treated as punctuation. Defaults to `_punctuations`. - - pad (str): - Special padding character that would be ignored by the model. Defaults to `_pad`. - - eos (str): - End of the sentence character. Defaults to `_eos`. - - bos (str): - Beginning of the sentence character. Defaults to `_bos`. - - is_unique (bool): - Remove duplicates from the provided characters. Defaults to True. - - is_sorted (bool): - Sort the characters in alphabetical order. Defaults to True. 
- """ - - def __init__( - self, - characters: str = _characters, - punctuations: str = _punctuations, - pad: str = _pad, - eos: str = _eos, - bos: str = _bos, - is_unique: bool = True, - is_sorted: bool = True, - ) -> None: - super().__init__(characters, punctuations, pad, eos, bos, is_unique, is_sorted) - - @staticmethod - def init_from_config(config: "Coqpit"): - return Graphemes( - **config.characters if config.characters is not None else {}, - ) - - - -if __name__ == "__main__": - gr = Graphemes() - ph = IPAPhonemes() - - print(gr.vocab) - print(ph.vocab) - - print(gr.num_chars) - assert "a" == gr.id_to_char(gr.char_to_id("a")) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index f6803ff6..775b4cb8 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -2,7 +2,7 @@ from typing import Callable, Dict, List, Union from TTS.tts.utils.text import cleaners from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name -from TTS.tts.utils.text.symbols import Graphemes, IPAPhonemes +from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes class TTSTokenizer: @@ -117,4 +117,6 @@ class TTSTokenizer: phonemizer = get_phonemizer_by_name(DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs) else: characters = Graphemes().init_from_config(config) - return TTSTokenizer(config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars) \ No newline at end of file + return TTSTokenizer( + config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars + ) diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index fe4a9d9b..4762a77a 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -71,12 +71,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 
🚀 From 99d9bb7a174fce17f13c9f988e75d76703023066 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 19 Nov 2021 10:29:24 +0100 Subject: [PATCH 040/214] Test Phonemizers --- tests/text_tests/test_characters.py | 49 +++++----- tests/text_tests/test_phonemizer.py | 144 ++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 25 deletions(-) create mode 100644 tests/text_tests/test_phonemizer.py diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py index 5a051ac4..ed84b5b4 100644 --- a/tests/text_tests/test_characters.py +++ b/tests/text_tests/test_characters.py @@ -1,12 +1,6 @@ import unittest -from TTS.tts.utils.text.characters import ( - BaseCharacters, - IPAPhonemes, - Graphemes, - create_graphemes, - create_phonemes, -) +from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, create_graphemes, create_phonemes def test_make_symbols(): @@ -16,16 +10,7 @@ def test_make_symbols(): class BaseCharacterTest(unittest.TestCase): def setUp(self): - self.characters_empty = BaseCharacters( - "", - "", - pad="", - eos="", - bos="", - blank="", - is_unique=True, - is_sorted=True - ) + self.characters_empty = BaseCharacters("", "", pad="", eos="", bos="", blank="", is_unique=True, is_sorted=True) def test_default_character_sets(self): """Test initiation of default character sets""" @@ -41,8 +26,10 @@ class BaseCharacterTest(unittest.TestCase): self.characters_empty.bos = "[BOS]" self.characters_empty.blank = "[BLANK]" - self.assertEqual(self.characters_empty.num_chars, len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "])) - + self.assertEqual( + self.characters_empty.num_chars, + len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]), + ) def test_unique_sorted(self): """Test if the unique and sorted option works""" @@ -53,7 +40,10 @@ class BaseCharacterTest(unittest.TestCase): self.characters_empty.bos = "[BOS]" self.characters_empty.blank = "[BLANK]" - self.assertEqual(self.characters_empty.num_chars, len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "])) + self.assertEqual( + self.characters_empty.num_chars, + len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]), + ) def test_setters_getters(self): """Test the class setters behaves as expected""" @@ -71,16 +61,26 @@ class BaseCharacterTest(unittest.TestCase): self.characters_empty.eos = "[EOS]" self.assertEqual(self.characters_empty._eos, "[EOS]") - self.assertEqual(self.characters_empty.vocab, ["[PAD]", "[EOS]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + self.assertEqual( + self.characters_empty.vocab, ["[PAD]", "[EOS]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "] + ) self.characters_empty.bos = "[BOS]" self.assertEqual(self.characters_empty._bos, "[BOS]") - self.assertEqual(self.characters_empty.vocab, ["[PAD]", "[EOS]", "[BOS]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + self.assertEqual( + self.characters_empty.vocab, ["[PAD]", "[EOS]", "[BOS]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "] + ) self.characters_empty.blank = "[BLANK]" self.assertEqual(self.characters_empty._blank, "[BLANK]") - self.assertEqual(self.characters_empty.vocab, ["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) - self.assertEqual(self.characters_empty.num_chars, len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "])) + 
self.assertEqual( + self.characters_empty.vocab, + ["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "], + ) + self.assertEqual( + self.characters_empty.num_chars, + len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]), + ) self.characters_empty.print_log() @@ -124,4 +124,3 @@ class BaseCharacterTest(unittest.TestCase): self.assertEqual(self.characters_empty.id_to_char(11), "!") self.assertEqual(self.characters_empty.id_to_char(12), "?") self.assertEqual(self.characters_empty.id_to_char(13), " ") - diff --git a/tests/text_tests/test_phonemizer.py b/tests/text_tests/test_phonemizer.py new file mode 100644 index 00000000..cd0adfe1 --- /dev/null +++ b/tests/text_tests/test_phonemizer.py @@ -0,0 +1,144 @@ +import unittest + +from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, create_graphemes, create_phonemes +from TTS.tts.utils.text.phonemizers import ESpeak, Gruut, JA_JP_Phonemizer, ZH_CN_Phonemizer +from TTS.tts.utils.text.tokenizer import TTSTokenizer + +EXAMPLE_TEXT = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" + + +class TestEspeakPhonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = ESpeak(language="en-us") + self.EXPECTED_PHONEMES = "ɹ|ˈiː|s|ə|n|t ɹ|ɪ|s|ˈɜː|tʃ æ|t h|ˈɑːɹ|v|ɚ|d h|ɐ|z ʃ|ˈoʊ|n m|ˈɛ|d|ᵻ|t|ˌeɪ|ɾ|ɪ|ŋ f|ɔː|ɹ æ|z l|ˈɪ|ɾ|əl æ|z ˈeɪ|t w|ˈiː|k|s k|æ|n ˈæ|k|tʃ|uː|əl|i| ˈɪ|n|k|ɹ|iː|s, ð|ə ɡ|ɹ|ˈeɪ m|ˈæ|ɾ|ɚ|ɹ ɪ|n|ð|ə p|ˈɑːɹ|t|s ʌ|v|ð|ə b|ɹ|ˈeɪ|n ɹ|ɪ|s|p|ˈɑː|n|s|ə|b|əl f|ɔː|ɹ ɪ|m|ˈoʊ|ʃ|ə|n|əl ɹ|ˌɛ|ɡ|j|uː|l|ˈeɪ|ʃ|ə|n|| æ|n|d l|ˈɜː|n|ɪ|ŋ!" + + def test_phonemize(self): + output = self.phonemizer.phonemize(EXAMPLE_TEXT, separator="|") + self.assertEqual(output, self.EXPECTED_PHONEMES) + + # multiple punctuations + text = "Be a voice, not an! echo?" + gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ?" + output = self.phonemizer.phonemize(text, separator="|") + output = output.replace("|", "") + self.assertEqual(output, gt) + + # not ending with punctuation + text = "Be a voice, not an! echo" + gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ" + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + # extra space after the sentence + text = "Be a voice, not an! echo. " + gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ." + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "espeak") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), dict) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) + + +class TestGruutPhonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = Gruut(language="en-us", use_espeak_phonemes=True, keep_stress=False) + self.EXPECTED_PHONEMES = "ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ| f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ| ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l| f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!" 
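The fixture above pins down the exact phonemizer configuration; the call pattern the assertions below rely on can be reproduced directly. A minimal sketch, assuming gruut and its en-us language data are installed:

from TTS.tts.utils.text.phonemizers import Gruut

phonemizer = Gruut(language="en-us", use_espeak_phonemes=True, keep_stress=False)
# the separator is placed between phonemes; an empty separator returns plain IPA
output = phonemizer.phonemize("Be a voice, not an! echo?", separator="")
print(output)  # the test below expects "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"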
+ + def test_phonemize(self): + output = self.phonemizer.phonemize(EXAMPLE_TEXT, separator="|") + self.assertEqual(output, self.EXPECTED_PHONEMES) + + # multiple punctuations + text = "Be a voice, not an! echo?" + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?" + output = self.phonemizer.phonemize(text, separator="|") + output = output.replace("|", "") + self.assertEqual(output, gt) + + # not ending with punctuation + text = "Be a voice, not an! echo" + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + # extra space after the sentence + text = "Be a voice, not an! echo. " + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ." + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "gruut") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), list) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) + + +class TestJA_JPPhonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = JA_JP_Phonemizer() + self._TEST_CASES = """ + どちらに行きますか?/dochiraniikimasuka? + 今日は温泉に、行きます。/kyo:waoNseNni,ikimasu. + 「A」から「Z」までです。/e:karazeqtomadedesu. + そうですね!/so:desune! + クジラは哺乳類です。/kujirawahonyu:ruidesu. + ヴィディオを見ます。/bidioomimasu. + 今日は8月22日です/kyo:wahachigatsuniju:ninichidesu + xyzとαβγ/eqkusuwaizeqtotoarufabe:tagaNma + 値段は$12.34です/nedaNwaju:niteNsaNyoNdorudesu + """ + + def test_phonemize(self): + for line in self._TEST_CASES.strip().split("\n"): + text, phone = line.split("/") + self.assertEqual(self.phonemizer.phonemize(text, separator=""), phone) + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "ja_jp_phonemizer") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), dict) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) + + +class TestZH_CN_Phonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = ZH_CN_Phonemizer() + self._TEST_CASES = "" + + def test_phonemize(self): + # TODO: implement ZH phonemizer tests + pass + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "zh_cn_phonemizer") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), dict) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) \ No newline at end of file From 20e5dd367843aa38a73212ad5da455e67ee4ad19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 19 Nov 2021 10:35:38 +0100 Subject: [PATCH 041/214] Add doc examples --- TTS/tts/utils/text/phonemizers/espeak_wrapper.py | 12 ++++++++++-- TTS/tts/utils/text/phonemizers/gruut_wrapper.py | 7 +++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index 59c4a8ee..45169c17 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -66,6 +66,14 @@ class ESpeak(BasePhonemizer): keep_puncs (bool): If True, keep the punctuations after phonemization. Defaults to True. 
+ + Example: + + >>> from TTS.tts.utils.text.phonemizers import ESpeak + >>> phonemizer = ESpeak("tr") + >>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|") + 'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.' + """ _ESPEAK_LIB = _DEF_ESPEAK_LIB @@ -140,14 +148,14 @@ class ESpeak(BasePhonemizer): count += 1 return langs - def version(self): + def version(self) -> str: """Get the version of the used backend. Returns: str: Version of the used backend. """ args = ["--version"] - for line in self._espeak_exe(args, sync=True): + for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True): version = line.decode("utf8").strip().split()[2] logging.debug("line: %s" % repr(line)) return version diff --git a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py index a1ad1b80..d0aa469e 100644 --- a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py @@ -30,6 +30,13 @@ class Gruut(BasePhonemizer): keep_stress (bool): If true, keep the stress characters after phonemization. Defaults to False. + + Example: + + >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + >>> phonemizer = Gruut('en-us') + >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|") + 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?' """ def __init__( From f0655bfffc1e70d161bc7ba7339dfc62f22a1ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 19 Nov 2021 10:36:36 +0100 Subject: [PATCH 042/214] Fix ja_jp_phonemizer --- .../text/phonemizers/ja_jp_phonemizer.py | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py index fcd170ba..714e9832 100644 --- a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py +++ b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py @@ -5,30 +5,50 @@ from TTS.tts.utils.text.phonemizers.base import BasePhonemizer _DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】" +_TRANS_TABLE = {"、": ","} + + +def trans(text): + for i, j in _TRANS_TABLE.items(): + text = text.replace(i, j) + return text + class JA_JP_Phonemizer(BasePhonemizer): """🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer` TODO: someone with JA knowledge should check this implementation + + Example: + + >>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer + >>> phonemizer = JA_JP_Phonemizer() + >>> phonemizer.phonemize("どちらに行きますか?", separator="|") + d|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|? + """ language = "ja-jp" - def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=False, **kwargs): + def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) @staticmethod def name(): return "ja_jp_phonemizer" - def phonemize_jajp(self, text: str, separator: str = "|") -> str: + def _phonemize(self, text: str, separator: str = "|") -> str: ph = japanese_text_to_phonemes(text) if separator is not None or separator != "": return separator.join(ph) return ph - def _phonemize(self, text, separator): - return self.phonemize_jajp(text, separator) + def phonemize(self, text: str, separator="|") -> str: + """Custom phonemize for JP_JA + + Skip pre-post processing steps used by the other phonemizers. 
+        """
+        return self._phonemize(text, separator)
 
     @staticmethod
     def supported_languages() -> Dict:

From 10d435ce77d0240c8ede5a8f68f916bf188445f5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 19 Nov 2021 10:36:50 +0100
Subject: [PATCH 043/214] Fixup

---
 TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py
index 714e9832..4f93edeb 100644
--- a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py
+++ b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py
@@ -24,7 +24,7 @@ class JA_JP_Phonemizer(BasePhonemizer):
     >>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer
     >>> phonemizer = JA_JP_Phonemizer()
     >>> phonemizer.phonemize("どちらに行きますか?", separator="|")
-    d|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?
+    'd|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?'
 
     """

From ff7c3858389ba250f761d76592cb060ac6be05c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 19 Nov 2021 10:37:12 +0100
Subject: [PATCH 044/214] Fix BasePhonemizer

---
 TTS/tts/utils/text/phonemizers/base.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py
index a14a73bc..249c8bce 100644
--- a/TTS/tts/utils/text/phonemizers/base.py
+++ b/TTS/tts/utils/text/phonemizers/base.py
@@ -92,8 +92,12 @@ class BasePhonemizer(abc.ABC):
     def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
         """Preprocess the text before phonemization
 
+        1. remove spaces
+        2. remove punctuation
+
         Override this if you need a different behaviour
         """
+        text = text.strip()
         if self._keep_puncs:
             # a tuple (text, punctuation marks)
             return self._punctuator.strip_to_restore(text)

From d8bdeb8b8f0d06d3f0623be209cf9adeb9eeacd0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 19 Nov 2021 10:39:21 +0100
Subject: [PATCH 045/214] Fix Punctuation

---
 TTS/tts/utils/text/characters.py  | 52 +++++++++++++++----------------
 TTS/tts/utils/text/punctuation.py |  7 +++--
 TTS/tts/utils/text/tokenizer.py   |  2 +-
 3 files changed, 32 insertions(+), 29 deletions(-)

diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py
index 05882e57..c1342e78 100644
--- a/TTS/tts/utils/text/characters.py
+++ b/TTS/tts/utils/text/characters.py
@@ -74,36 +74,36 @@ phonemes = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
 class BaseCharacters:
     """🐸BaseCharacters class
 
-    Every new character class should inherit from this.
+    Every new character class should inherit from this.
 
-    Characters are ordered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```.
+    Characters are ordered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```.
 
-    If you need a custom order, you need to inherit from this class and override the ```_create_vocab``` method.
+    If you need a custom order, you need to inherit from this class and override the ```_create_vocab``` method.
 
-    Args:
-        characters (str):
-            Main set of characters to be used in the vocabulary.
+    Args:
+        characters (str):
+            Main set of characters to be used in the vocabulary.
 
-        punctuations (str):
-            Characters to be treated as punctuation.
+        punctuations (str):
+            Characters to be treated as punctuation.
 
-        pad (str):
-            Special padding character that would be ignored by the model.
+        pad (str):
+            Special padding character that would be ignored by the model.
-        eos (str):
-            End of the sentence character.
+        eos (str):
+            End of the sentence character.
 
-        bos (str):
-            Beginning of the sentence character.
+        bos (str):
+            Beginning of the sentence character.
 
-        blank (str):
-            Optional character used between characters by some models for better prosody.
+        blank (str):
+            Optional character used between characters by some models for better prosody.
 
-        is_unique (bool):
-            Remove duplicates from the provided characters. Defaults to True.
-        is_sorted (bool):
-            Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True.
+        is_unique (bool):
+            Remove duplicates from the provided characters. Defaults to True.
+
+        is_sorted (bool):
+            Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True.
     """
 
     def __init__(
@@ -196,10 +196,10 @@ class BaseCharacters:
         if self.is_sorted:
             _vocab = sorted(_vocab)
         _vocab = list(_vocab)
-        _vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab
-        _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
-        _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
-        _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
+        _vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab
+        _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
+        _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
+        _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
         self._vocab = _vocab + list(self._punctuations)
         self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
         self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
         if self.is_unique:
             assert (
                 len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
             ), f" [!] There are duplicate characters in the character set."
 
-    def print_log(self, level:int=0):
+    def print_log(self, level: int = 0):
         """
         Prints the vocabulary in a nice format.
""" diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py index 624cea88..36e1f194 100644 --- a/TTS/tts/utils/text/punctuation.py +++ b/TTS/tts/utils/text/punctuation.py @@ -91,10 +91,13 @@ class Punctuation: puncs.append(_PUNC_IDX(match.group(), position)) # convert str text to a List[str], each item is separated by a punctuation splitted_text = [] - for punc in puncs: + for idx, punc in enumerate(puncs): split = text.split(punc.punc) prefix, suffix = split[0], punc.punc.join(split[1:]) splitted_text.append(prefix) + # if the text does not end with a punctuation, add it to the last item + if idx == len(puncs) - 1 and len(suffix) > 0: + splitted_text.append(suffix) text = suffix return splitted_text, puncs @@ -126,7 +129,7 @@ class Punctuation: current = puncs[0] if current.position == PuncPosition.BEGIN: - return cls._restore([current.mark + text[0]] + text[1:], puncs[1:], num) + return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num) if current.position == PuncPosition.END: return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 775b4cb8..8c231c14 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -1,8 +1,8 @@ from typing import Callable, Dict, List, Union from TTS.tts.utils.text import cleaners -from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name class TTSTokenizer: From 79a84410f2f01091b745bb8d940989ff3c8b36a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 19 Nov 2021 17:17:11 +0100 Subject: [PATCH 046/214] Test punctuations --- TTS/tts/utils/text/punctuation.py | 19 +++++++++++++++--- tests/text_tests/test_punctuation.py | 30 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 tests/text_tests/test_punctuation.py diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py index 36e1f194..414ac253 100644 --- a/TTS/tts/utils/text/punctuation.py +++ b/TTS/tts/utils/text/punctuation.py @@ -19,12 +19,25 @@ class PuncPosition(Enum): class Punctuation: - """Handle punctuations characters in text. + """Handle punctuations in text. Just strip punctuations from text or strip and restore them later. Args: puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`. + + Example: + >>> punc = Punctuation() + >>> punc.strip("This is. example !") + 'This is example' + + >>> text_striped, punc_map = punc.strip_to_restore("This is. example !") + >>> ' '.join(text_striped) + 'This is example' + + >>> text_restored = punc.restore(text_striped, punc_map) + >>> text_restored[0] + 'This is. example !' """ def __init__(self, puncs: str = _DEF_PUNCS): @@ -43,7 +56,7 @@ class Punctuation: def puncs(self, value): if not isinstance(value, six.string_types): raise ValueError("[!] Punctuations must be of type str.") - self._puncs = "".join(set(value)) + self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+") def strip(self, text): @@ -56,7 +69,7 @@ class Punctuation: "This is. example !" 
-> "This is example " """ - return re.sub(self.puncs_regular_exp, " ", text).strip() + return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip() def strip_to_restore(self, text): """Remove punctuations from text to restore them later. diff --git a/tests/text_tests/test_punctuation.py b/tests/text_tests/test_punctuation.py new file mode 100644 index 00000000..f349bc50 --- /dev/null +++ b/tests/text_tests/test_punctuation.py @@ -0,0 +1,30 @@ +import unittest +from TTS.tts.utils.text.punctuation import Punctuation, _DEF_PUNCS + +class PunctuationTest(unittest.TestCase): + def setUp(self): + self.punctuation = Punctuation() + self.test_texts = [("This, is my text ... to be striped !! from text?", "This is my text to be striped from text"), + ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"), + ("This, is my text ... to be striped from text?", "This is my text to be striped from text"), + ("This, is my text to be striped from text", "This is my text to be striped from text") + ] + + def test_get_set_puncs(self): + self.punctuation.puncs = "-=" + self.assertEqual(self.punctuation.puncs, "-=") + + self.punctuation.puncs = _DEF_PUNCS + self.assertEqual(self.punctuation.puncs, _DEF_PUNCS) + + def test_strip_punc(self): + for text, gt in self.test_texts: + text_striped = self.punctuation.strip(text) + self.assertEqual(text_striped, gt) + + def test_strip_restore(self): + for text, gt in self.test_texts: + text_striped, puncs_map = self.punctuation.strip_to_restore(text) + text_restored = self.punctuation.restore(text_striped, puncs_map) + self.assertEqual(' '.join(text_striped), gt) + self.assertEqual(text_restored[0], text) From ba3b60c90f7caf7ab9f7edec1e42817412dc00b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 19 Nov 2021 18:04:57 +0100 Subject: [PATCH 047/214] Test TTSTokenizer --- TTS/tts/utils/text/tokenizer.py | 17 ++++-- tests/text_tests/test_tokenizer.py | 88 ++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 4 deletions(-) create mode 100644 tests/text_tests/test_tokenizer.py diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 8c231c14..e79cf5e5 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -21,6 +21,14 @@ class TTSTokenizer: phonemizer (Phonemizer): A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None. + Example: + + >>> from TTS.tts.utils.text.tokenizer import TTSTokenizer + >>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) + >>> text = "Hello world!" 
+ >>> ids = tokenizer.text_to_ids(text) + >>> text_hat = tokenizer.ids_to_text(ids) + >>> assert text == text_hat """ def __init__( @@ -89,21 +97,22 @@ class TTSTokenizer: return [self.characters.bos] + list(char_sequence) + [self.characters.eos] def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False): - char_to_use = self.characters.blank_char if use_blank_char else self.characters.pad + char_to_use = self.characters.blank if use_blank_char else self.characters.pad result = [char_to_use] * (len(char_sequence) * 2 + 1) result[1::2] = char_sequence return result - def print_logs(self, level: int = 1): + def print_logs(self, level: int = 0): indent = "\t" * level print(f"{indent}| > add_blank: {self.use_phonemes}") print(f"{indent}| > use_eos_bos: {self.use_phonemes}") print(f"{indent}| > use_phonemes: {self.use_phonemes}") - print(f"{indent}| > phonemizer: {self.phonemizer.print_logs(level + 1)}") + if self.use_phonemes: + print(f"{indent}| > phonemizer: {self.phonemizer.print_logs(level + 1)}") @staticmethod def init_from_config(config: "Coqpit"): - """Init Tokenizer object from the config. + """Init Tokenizer object from config Args: config (Coqpit): Coqpit model config. diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py new file mode 100644 index 00000000..6b7982cd --- /dev/null +++ b/tests/text_tests/test_tokenizer.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from os import sep +import unittest + +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes, _phonemes, _punctuations, _eos, _bos, _pad, _blank +from TTS.tts.utils.text.phonemizers import ESpeak + +from coqpit import Coqpit + + +class TestTTSTokenizer(unittest.TestCase): + def setUp(self): + self.tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) + + self.ph = ESpeak('tr') + self.tokenizer_ph = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph) + + def test_encode_decode_graphemes(self): + text = "This is, a test." + ids = self.tokenizer.encode(text) + test_hat = self.tokenizer.decode(ids) + self.assertEqual(text, test_hat) + self.assertEqual(len(ids), len(text)) + + def test_text_to_ids_phonemes(self): + # TODO: note sure how to extend to cover all the languages and phonemizer. + text = "Bu bir Örnek." + text_ph = self.ph.phonemize(text, separator="") + ids = self.tokenizer_ph.text_to_ids(text) + test_hat = self.tokenizer_ph.ids_to_text(ids) + self.assertEqual(text_ph, test_hat) + + def test_text_to_ids_phonemes_with_eos_bos(self): + text = "Bu bir Örnek." + self.tokenizer_ph.use_eos_bos = True + text_ph = IPAPhonemes().bos + self.ph.phonemize(text, separator="") + IPAPhonemes().eos + ids = self.tokenizer_ph.text_to_ids(text) + test_hat = self.tokenizer_ph.ids_to_text(ids) + self.assertEqual(text_ph, test_hat) + + def test_text_to_ids_phonemes_with_eos_bos_and_blank(self): + text = "Bu bir Örnek." + self.tokenizer_ph.use_eos_bos = True + self.tokenizer_ph.add_blank = True + text_ph = "bʊ bɪr œrnˈɛc." 
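The intersperse_blank_char fix earlier in this patch (blank_char → blank) is easiest to see in isolation. A toy sketch of the same indexing trick, with hypothetical symbols:

# the blank (or pad) symbol goes around and between every character
seq = ["a", "b", "c"]
blank = "<BLNK>"
result = [blank] * (len(seq) * 2 + 1)
result[1::2] = seq
print(result)  # ['<BLNK>', 'a', '<BLNK>', 'b', '<BLNK>', 'c', '<BLNK>']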
+ ids = self.tokenizer_ph.text_to_ids(text) + text_hat = self.tokenizer_ph.ids_to_text(ids) + self.assertEqual(text_ph, text_hat) + + def test_print_logs(self): + self.tokenizer.print_logs() + self.tokenizer_ph.print_logs() + + def test_init_from_config(self): + + @dataclass + class Characters(Coqpit): + characters: str = _phonemes + punctuations: str = _punctuations + pad: str = _pad + eos: str = _eos + bos: str = _bos + blank: str = _blank + is_unique: bool = True + is_sorted: bool = True + + @dataclass + class TokenizerConfig(Coqpit): + enable_eos_bos_chars: bool = True + use_phonemes: bool = True + add_blank: bool = False + characters: str = Characters() + phonemizer: str = "espeak" + phoneme_language: str = "tr" + text_cleaner: str = "phoneme_cleaners" + characters = Characters() + + tokenizer_ph = TTSTokenizer.init_from_config(TokenizerConfig()) + text = "Bu bir Örnek." + text_ph = "" + self.ph.phonemize(text, separator="") + "" + ids = tokenizer_ph.text_to_ids(text) + test_hat = tokenizer_ph.ids_to_text(ids) + self.assertEqual(text_ph, test_hat) + + + + From f1ea3ad1825530416177c335d282f8bb38e9710f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 19 Nov 2021 18:07:00 +0100 Subject: [PATCH 048/214] Remove old text processing tests --- tests/aux_tests/test_text_processing.py | 104 ------------------------ 1 file changed, 104 deletions(-) delete mode 100644 tests/aux_tests/test_text_processing.py diff --git a/tests/aux_tests/test_text_processing.py b/tests/aux_tests/test_text_processing.py deleted file mode 100644 index 62d60a42..00000000 --- a/tests/aux_tests/test_text_processing.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Tests for text to phoneme converstion""" -import unittest - -from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme, text2phone - -# ----------------------------------------------------------------------------- - -LANG = "en-us" - -EXAMPLE_TEXT = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" - -EXPECTED_PHONEMES = "ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ| f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s|,| ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ| ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l| f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ|!" - -# ----------------------------------------------------------------------------- - - -class TextProcessingTestCase(unittest.TestCase): - """Tests for text to phoneme conversion""" - - def test_phoneme_to_sequence(self): - """Verify en-us sentence phonemes without blank token""" - self._test_phoneme_to_sequence(add_blank=False) - - def test_phoneme_to_sequence_with_blank_token(self): - """Verify en-us sentence phonemes with blank token""" - self._test_phoneme_to_sequence(add_blank=True) - - def _test_phoneme_to_sequence(self, add_blank): - """Verify en-us sentence phonemes""" - text_cleaner = ["phoneme_cleaners"] - sequence = phoneme_to_sequence(EXAMPLE_TEXT, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = EXPECTED_PHONEMES.replace("|", "") - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # multiple punctuations - text = "Be a voice, not an! echo?" 
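The suite deleted below covered phoneme_to_sequence round trips; the new tokenizer tests above cover the same ground. Roughly the equivalent check under the new API, assuming espeak is available:

from TTS.tts.utils.text.characters import IPAPhonemes
from TTS.tts.utils.text.phonemizers import ESpeak
from TTS.tts.utils.text.tokenizer import TTSTokenizer

ph = ESpeak("en-us")
tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=ph)
ids = tokenizer.text_to_ids("Be a voice, not an echo!")
# the round trip should reproduce the phonemized string, as in TestTTSTokenizer
assert tokenizer.ids_to_text(ids) == ph.phonemize("Be a voice, not an echo!", separator="")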
- sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?" - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # not ending with punctuation - text = "Be a voice, not an! echo" - sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # original - text = "Be a voice, not an echo!" - sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!" - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # extra space after the sentence - text = "Be a voice, not an! echo. " - sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ." - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # extra space after the sentence - text = "Be a voice, not an! echo. " - sequence = phoneme_to_sequence( - text, text_cleaner, LANG, enable_eos_bos=True, add_blank=add_blank, use_espeak_phonemes=True - ) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~" - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - def test_text2phone(self): - """Verify phones directly (with |)""" - ph = text2phone(EXAMPLE_TEXT, LANG, use_espeak_phonemes=True) - self.assertEqual(ph, EXPECTED_PHONEMES) - - -# ----------------------------------------------------------------------------- - -if __name__ == "__main__": - unittest.main() From 353f913efc500bdca73b2b77f57bc0c5d8ef1dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 12:48:16 +0100 Subject: [PATCH 049/214] Fix #985 --- TTS/vocoder/models/gan.py | 1 + 1 file changed, 1 insertion(+) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 76fee505..e9bab982 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -325,6 +325,7 @@ class GAN(BaseVocoder): data_items (List): Data samples. verbose (bool): Log information if true. num_gpus (int): Number of GPUs in use. + rank (int): Rank of the current GPU. Defaults to None. Returns: DataLoader: Torch dataloader. 
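The `rank` argument documented in the patch above is what lets a vocoder data
loader shard its data across processes during multi-GPU training. A minimal
sketch of the underlying pattern, assuming a generic PyTorch dataset rather
than the project's actual GAN loader (all names here are illustrative):

    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.distributed import DistributedSampler

    def get_data_loader(dataset: Dataset, batch_size: int, num_gpus: int, rank: int = None) -> DataLoader:
        # Shard the dataset only when more than one GPU is in use; `rank`
        # picks the shard that this particular process will iterate over.
        sampler = DistributedSampler(dataset, num_replicas=num_gpus, rank=rank) if num_gpus > 1 else None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=sampler is None,  # the distributed sampler shuffles on its own
            sampler=sampler,
            drop_last=False,
        )

Without the correct `rank`, every process would draw the same batches, which is
exactly the situation the documented argument exists to prevent.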
From e1b4c4ca43d390a2fe19a38b8a7660721532bc1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 12:59:51 +0100 Subject: [PATCH 050/214] Add init_from_config to GAN --- TTS/vocoder/models/gan.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index e9bab982..b4e3652e 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -361,3 +361,7 @@ class GAN(BaseVocoder): def get_criterion(self): """Return criterions for the optimizers""" return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)] + + @staticmethod + def init_from_config(config: Coqpit) -> "GAN": + return GAN(config) \ No newline at end of file From acc6eef625d0e65aa9c762aff3a3048898864688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 17:49:20 +0100 Subject: [PATCH 051/214] Update for tokenizer API --- TTS/utils/synthesizer.py | 14 +++++--------- TTS/vocoder/models/__init__.py | 3 +-- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 12b71ab6..2e4f4735 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -114,8 +114,7 @@ class Synthesizer(object): self.tts_config = load_config(tts_config_path) self.use_phonemes = self.tts_config.use_phonemes - self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) - self.tokenizer = TTSTokenizer.init_from_config(self.tts_config) + self.tts_model = setup_tts_model(config=self.tts_config) speaker_manager = self._init_speaker_manager() language_manager = self._init_language_manager() @@ -245,7 +244,7 @@ class Synthesizer(object): path (str): output path to save the waveform. """ wav = np.array(wav) - self.ap.save_wav(wav, path, self.output_sample_rate) + self.tts_model.ap.save_wav(wav, path, self.output_sample_rate) def tts( self, @@ -333,13 +332,10 @@ class Synthesizer(object): text=sen, CONFIG=self.tts_config, use_cuda=self.use_cuda, - ap=self.ap, - tokenizer=self.tokenizer, speaker_id=speaker_id, language_id=language_id, language_name=language_name, style_wav=style_wav, - enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars, use_griffin_lim=use_gl, d_vector=speaker_embedding, ) @@ -347,14 +343,14 @@ class Synthesizer(object): mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() if not use_gl: # denormalize tts output based on tts audio config - mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch scale_factor = [ 1, - self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: print(" > interpolating tts model output.") @@ -372,7 +368,7 @@ class Synthesizer(object): # trim silence if self.tts_config.audio["do_trim_silence"] is True: - waveform = trim_silence(waveform, self.ap) + waveform = trim_silence(waveform, self.tts_model.ap) wavs += list(waveform) wavs += [0] * 10000 diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index a70ebe40..65901617 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -28,8 +28,7 @@ def setup_model(config: Coqpit): except 
ModuleNotFoundError as e: raise ValueError(f"Model {config.model} not exist!") from e print(" > Vocoder Model: {}".format(config.model)) - model = MyModel(config) - return model + return MyModel.init_from_config(config) def setup_generator(c): From 693fb4dd3973e7ced1a33fbe9121fa461a54661f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 17:49:58 +0100 Subject: [PATCH 052/214] Modify init_from_config for IPAPhonemes --- TTS/tts/utils/text/characters.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index c1342e78..d27a562a 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -257,6 +257,9 @@ class IPAPhonemes(BaseCharacters): bos (str): Beginning of the sentence character. Defaults to `_bos`. + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + is_unique (bool): Remove duplicates from the provided characters. Defaults to True. @@ -279,9 +282,24 @@ class IPAPhonemes(BaseCharacters): @staticmethod def init_from_config(config: "Coqpit"): - return IPAPhonemes( - **config.characters if config.characters is not None else {}, - ) + # band-aid for compatibility with old models + characters = None + if "characters" in config: + if "phonemes" in config.characters: + config.characters["characters"] = config.characters["phonemes"] + # delattr(config.characters, "phonemes") + + return IPAPhonemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ) + return characters class Graphemes(BaseCharacters): From 22f0c58fe1d7f8ef67eb75d034708cbd81936926 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:37:25 +0100 Subject: [PATCH 053/214] Print language codes --- TTS/tts/utils/text/phonemizers/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py index c0ef7909..b00f7f5e 100644 --- a/TTS/tts/utils/text/phonemizers/__init__.py +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -49,3 +49,7 @@ def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: if name == "ja_jp_phonemizer": return JA_JP_Phonemizer(**kwargs) raise ValueError(f"Phonemizer {name} not found") + + +if __name__ == "__main__": + print(DEF_LANG_TO_PHONEMIZER) \ No newline at end of file From 4e83bf396811a2acefeca0202fd9e23d540e90ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:37:54 +0100 Subject: [PATCH 054/214] Allow choosing phonemizer --- TTS/tts/utils/text/tokenizer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index e79cf5e5..4163a0e2 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -117,14 +117,22 @@ class TTSTokenizer: Args: config (Coqpit): Coqpit model config. 
""" + # init cleaners if isinstance(config.text_cleaner, (str, list)): text_cleaner = getattr(cleaners, config.text_cleaner) if config.use_phonemes: + # init phoneme set characters = IPAPhonemes().init_from_config(config) phonemizer_kwargs = {"language": config.phoneme_language} - phonemizer = get_phonemizer_by_name(DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs) + + # init phonemizer + if "phonemizer" in config and config.phonemizer: + phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs) + else: + phonemizer = get_phonemizer_by_name(DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs) else: + # init character set characters = Graphemes().init_from_config(config) return TTSTokenizer( config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars From d8ec7086b6fe6490fbe24b391fd3f2f7daa2dbf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:41:21 +0100 Subject: [PATCH 055/214] Update `synthesis` for the new API --- TTS/tts/utils/synthesis.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index a4f4b0c8..9d9660aa 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -113,8 +113,6 @@ def synthesis( text, CONFIG, use_cuda, - ap, - tokenizer, speaker_id=None, style_wav=None, use_griffin_lim=False, @@ -139,9 +137,6 @@ def synthesis( use_cuda (bool): Enable/disable CUDA. - ap (TTS.tts.utils.audio.AudioProcessor): - The audio processor for extracting features and pre/post-processing audio. - speaker_id (int): Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. @@ -169,10 +164,10 @@ def synthesis( if isinstance(style_wav, dict): style_mel = style_wav else: - style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) + style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) # convert text to sequence of token IDs text_inputs = np.asarray( - tokenizer.text_to_ids(text), + model.tokenizer.text_to_ids(text), dtype=np.int32, ) # pass tensors to backend From 3de9f38d166c581f3e7943d90ec31c7729c52727 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:41:47 +0100 Subject: [PATCH 056/214] Add init_from_config to SpeakerManager --- TTS/tts/utils/speakers.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 441296ac..c556db79 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -318,6 +318,30 @@ class SpeakerManager: # TODO: implement speaker encoder raise NotImplementedError + @staticmethod + def init_from_config(config: "Coqpit"): + """Initialize a speaker manager from config + + Args: + config (Coqpit): Config object. + + Returns: + SpeakerEncoder: Speaker encoder object. 
+        """
+        speaker_manager = None
+        if hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True:
+            if config.get("speaker_file", None):
+                speaker_manager = SpeakerManager(speaker_id_file_path=config.speaker_file)
+            if config.get("speakers_file", None):
+                speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
+
+        if hasattr(config, "use_d_vector_file") and config.use_d_vector_file is True:
+            if config.get("speakers_file", None):
+                speaker_manager = SpeakerManager(d_vectors_file_path=config.speakers_file)
+            if config.get("d_vector_file", None):
+                speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
+        return speaker_manager
+

 def _set_file_path(path):
     """Find the speakers.json under the given path or the above it.

From 87bf940676b03c6c3d5a255cb1b81bc68b86092d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 24 Nov 2021 18:42:00 +0100
Subject: [PATCH 057/214] Print duplicate characters

---
 TTS/tts/utils/text/characters.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py
index d27a562a..d2872c73 100644
--- a/TTS/tts/utils/text/characters.py
+++ b/TTS/tts/utils/text/characters.py
@@ -206,7 +206,7 @@ class BaseCharacters:
         if self.is_unique:
             assert (
                 len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
-            ), f" [!] There are duplicate characters in the character set."
+            ), f" [!] There are duplicate characters in the character set. {set([x for x in self.vocab if self.vocab.count(x) > 1])}"

     def char_to_id(self, char: str) -> int:
         return self._char_to_id[char]

From 73d27ebd45584020aafa9447ef8a22012b4ddb0e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 24 Nov 2021 18:42:44 +0100
Subject: [PATCH 058/214] Fix GlowTTS

---
 TTS/tts/models/glow_tts.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 9e779f8e..3e5226aa 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -14,6 +14,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.io import load_fsspec

@@ -513,3 +514,22 @@ class GlowTTS(BaseTTS):
     def on_train_step_start(self, trainer):
         """Decide on every training step wheter enable/disable data depended initialization."""
         self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
+
+    @staticmethod
+    def init_from_config(config: Coqpit):
+        """Initialize model from config."""
+
+        # init characters
+        if config.use_phonemes:
+            from TTS.tts.utils.text.characters import IPAPhonemes
+            characters = IPAPhonemes().init_from_config(config)
+        else:
+            from TTS.tts.utils.text.characters import Graphemes
+            characters = Graphemes().init_from_config(config)
+        config.num_chars = characters.num_chars
+
+        from TTS.utils.audio import AudioProcessor
+        ap = AudioProcessor.init_from_config(config)
+        tokenizer = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config)
+        return GlowTTS(config, ap, tokenizer, speaker_manager)
\ No newline at end of file

From d2525abe8c0f05e284d3e053b83351d30f8d0f9d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 24 Nov 2021 18:43:48 +0100 Subject: [PATCH 059/214] Remove get_characters from BaseTTS --- TTS/tts/models/base_tts.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 98f68742..8bb7f02e 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -10,12 +10,10 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from TTS.model import BaseModel -from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.characters import Graphemes, make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file @@ -74,22 +72,6 @@ class BaseTTS(BaseModel): else: raise ValueError("config must be either a *Config or *Args") - # @staticmethod - # def get_characters(config: Coqpit) -> str: - # # TODO: implement CharacterProcessor - # if config.characters is not None: - # symbols, phonemes = make_symbols(**config.characters) - # else: - # from TTS.tts.utils.text.characters import parse_symbols, phonemes, symbols - - # if config.use_phonemes: - - # config.characters = Graphemes() - - # model_characters = phonemes if config.use_phonemes else symbols - # num_chars = len(model_characters) + getattr(config, "add_blank", False) - # return model_characters, config, num_chars - def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: return get_speaker_manager(config, restore_path, data, out_path) From 3eca5ad0605c717b26437541263ec97d173b8b8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:44:18 +0100 Subject: [PATCH 060/214] Update config fields for phonemizer --- TTS/tts/configs/shared_configs.py | 40 ++++++++++++++++--------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 65ed21de..b101b70a 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -50,7 +50,7 @@ class GSTConfig(Coqpit): @dataclass class CharactersConfig(Coqpit): - """Defines character or phoneme set used by the model + """Defines arguments for the `BaseCharacters` and its subclasses. Args: pad (str): @@ -62,6 +62,9 @@ class CharactersConfig(Coqpit): bos (str): characters showing the beginning of a sentence. Defaults to None. + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + characters (str): character set used by the model. Characters not in this list are ignored when converting input text to a list of sequence IDs. Defaults to None. @@ -70,32 +73,26 @@ class CharactersConfig(Coqpit): characters considered as punctuation as parsing the input sentence. Defaults to None. phonemes (str): - characters considered as parsing phonemes. Defaults to None. + characters considered as parsing phonemes. This is only for backwards compat. Use `characters` for new + models. Defaults to None. - unique (bool): + is_unique (bool): remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old models trained with character lists with duplicates. 
+ + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. """ pad: str = None eos: str = None bos: str = None + blank: str = None characters: str = None punctuations: str = None phonemes: str = None - unique: bool = True # for backwards compatibility of models trained with char sets with duplicates - - def check_values( - self, - ): - """Check config fields""" - c = asdict(self) - check_argument("pad", c, prerequest="characters", restricted=True) - check_argument("eos", c, prerequest="characters", restricted=True) - check_argument("bos", c, prerequest="characters", restricted=True) - check_argument("characters", c, prerequest="characters", restricted=True) - check_argument("phonemes", c, restricted=True) - check_argument("punctuations", c, prerequest="characters", restricted=True) + is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates + is_sorted: bool = True @dataclass @@ -110,8 +107,13 @@ class BaseTTSConfig(BaseTrainingConfig): use_phonemes (bool): enable / disable phoneme use. - use_espeak_phonemes (bool): - enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`). + phonemizer (str): + Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`. + Defaults to None. + + phoneme_language (str): + Language code for the phonemizer. You can check the list of supported languages by running + `python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None. compute_input_seq_cache (bool): enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of @@ -195,7 +197,7 @@ class BaseTTSConfig(BaseTrainingConfig): audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) # phoneme settings use_phonemes: bool = False - use_espeak_phonemes: bool = True + phonemizer: str = None phoneme_language: str = None compute_input_seq_cache: bool = False text_cleaner: str = None From 9b83e665fcef1c2ef8928ec81f06d7cd74e0863d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:45:32 +0100 Subject: [PATCH 061/214] Add init_from_config as an abstract class --- TTS/model.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/TTS/model.py b/TTS/model.py index a7c64dde..efa00b2a 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -130,6 +130,15 @@ class BaseModel(nn.Module, ABC): """ ... + @staticmethod + @abstractmethod + def init_from_config(config: Coqpit): + """Init the model from given config. + + Override this depending on your model. 
+ """ + pass + def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: """Setup an return optimizer or optimizers.""" pass @@ -150,3 +159,4 @@ class BaseModel(nn.Module, ABC): def format_batch(self): pass + From bb389479a46b5843ce765381a0ddff3078cacbb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:46:18 +0100 Subject: [PATCH 062/214] Update setup_model for TTS.tts models --- TTS/tts/models/__init__.py | 44 ++------------------------------------ 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index c8371106..cb1c2e21 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,52 +1,12 @@ -from TTS.tts.utils.text.characters import make_symbols, parse_symbols from TTS.utils.generic_utils import find_module -def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None): +def setup_model(config: "Coqpit") -> "BaseTTS": print(" > Using model: {}".format(config.model)) # fetch the right model implementation. if "base_model" in config and config["base_model"] is not None: MyModel = find_module("TTS.tts.models", config.base_model.lower()) else: MyModel = find_module("TTS.tts.models", config.model.lower()) - # define set of characters used by the model - if config.characters is not None: - # set characters from config - if hasattr(MyModel, "make_symbols"): - symbols = MyModel.make_symbols(config) - else: - symbols, phonemes = make_symbols(**config.characters) - else: - from TTS.tts.utils.text.characters import phonemes, symbols # pylint: disable=import-outside-toplevel - - if config.use_phonemes: - symbols = phonemes - # use default characters and assign them to config - config.characters = parse_symbols() - # consider special `blank` character if `add_blank` is set True - num_chars = len(symbols) + getattr(config, "add_blank", False) - config.num_chars = num_chars - # compatibility fix - if "model_params" in config: - config.model_params.num_chars = num_chars - if "model_args" in config: - config.model_args.num_chars = num_chars - if config.model.lower() in ["vits"]: # If model supports multiple languages - model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager) - else: - model = MyModel(config, speaker_manager=speaker_manager) + model = MyModel.init_from_config(config) return model - - -# TODO; class registery -# def import_models(models_dir, namespace): -# for file in os.listdir(models_dir): -# path = os.path.join(models_dir, file) -# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): -# model_name = file[: file.find(".py")] if file.endswith(".py") else file -# importlib.import_module(namespace + "." 
+ model_name) -# -# -## automatically import any Python files in the models/ directory -# models_dir = os.path.dirname(__file__) -# import_models(models_dir, "TTS.tts.models") From 8c8093ce23a0b9730f45bf76620ed96f311879a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:46:48 +0100 Subject: [PATCH 063/214] Make style --- tests/text_tests/test_phonemizer.py | 2 +- tests/text_tests/test_punctuation.py | 17 ++++++++++------- tests/text_tests/test_tokenizer.py | 19 +++++++------------ 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/tests/text_tests/test_phonemizer.py b/tests/text_tests/test_phonemizer.py index cd0adfe1..aa7a5499 100644 --- a/tests/text_tests/test_phonemizer.py +++ b/tests/text_tests/test_phonemizer.py @@ -141,4 +141,4 @@ class TestZH_CN_Phonemizer(unittest.TestCase): self.assertIsInstance(self.phonemizer.version(), str) def test_is_available(self): - self.assertTrue(self.phonemizer.is_available()) \ No newline at end of file + self.assertTrue(self.phonemizer.is_available()) diff --git a/tests/text_tests/test_punctuation.py b/tests/text_tests/test_punctuation.py index f349bc50..141c10e4 100644 --- a/tests/text_tests/test_punctuation.py +++ b/tests/text_tests/test_punctuation.py @@ -1,14 +1,17 @@ import unittest -from TTS.tts.utils.text.punctuation import Punctuation, _DEF_PUNCS + +from TTS.tts.utils.text.punctuation import _DEF_PUNCS, Punctuation + class PunctuationTest(unittest.TestCase): def setUp(self): self.punctuation = Punctuation() - self.test_texts = [("This, is my text ... to be striped !! from text?", "This is my text to be striped from text"), - ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"), - ("This, is my text ... to be striped from text?", "This is my text to be striped from text"), - ("This, is my text to be striped from text", "This is my text to be striped from text") - ] + self.test_texts = [ + ("This, is my text ... to be striped !! from text?", "This is my text to be striped from text"), + ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"), + ("This, is my text ... 
to be striped from text?", "This is my text to be striped from text"), + ("This, is my text to be striped from text", "This is my text to be striped from text"), + ] def test_get_set_puncs(self): self.punctuation.puncs = "-=" @@ -26,5 +29,5 @@ class PunctuationTest(unittest.TestCase): for text, gt in self.test_texts: text_striped, puncs_map = self.punctuation.strip_to_restore(text) text_restored = self.punctuation.restore(text_striped, puncs_map) - self.assertEqual(' '.join(text_striped), gt) + self.assertEqual(" ".join(text_striped), gt) self.assertEqual(text_restored[0], text) diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py index 6b7982cd..8bee618b 100644 --- a/tests/text_tests/test_tokenizer.py +++ b/tests/text_tests/test_tokenizer.py @@ -1,19 +1,19 @@ +import unittest from dataclasses import dataclass from os import sep -import unittest - -from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes, _phonemes, _punctuations, _eos, _bos, _pad, _blank -from TTS.tts.utils.text.phonemizers import ESpeak from coqpit import Coqpit +from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes, _blank, _bos, _eos, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.phonemizers import ESpeak +from TTS.tts.utils.text.tokenizer import TTSTokenizer + class TestTTSTokenizer(unittest.TestCase): def setUp(self): self.tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) - self.ph = ESpeak('tr') + self.ph = ESpeak("tr") self.tokenizer_ph = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph) def test_encode_decode_graphemes(self): @@ -53,7 +53,6 @@ class TestTTSTokenizer(unittest.TestCase): self.tokenizer_ph.print_logs() def test_init_from_config(self): - @dataclass class Characters(Coqpit): characters: str = _phonemes @@ -80,9 +79,5 @@ class TestTTSTokenizer(unittest.TestCase): text = "Bu bir Örnek." text_ph = "" + self.ph.phonemize(text, separator="") + "" ids = tokenizer_ph.text_to_ids(text) - test_hat = tokenizer_ph.ids_to_text(ids) + test_hat = tokenizer_ph.ids_to_text(ids) self.assertEqual(text_ph, test_hat) - - - - From c39aaafbfc679c630ee9f672b93f3d5ad0b8333e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 25 Nov 2021 17:28:17 +0100 Subject: [PATCH 064/214] Update EspeakWrapper for espeak-ng --- .../utils/text/phonemizers/espeak_wrapper.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index 45169c17..f806f036 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -1,6 +1,5 @@ import logging import subprocess -import tempfile from typing import Dict, List from TTS.tts.utils.text.phonemizers.base import BasePhonemizer @@ -24,6 +23,7 @@ else: def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: cmd = [ espeak_lib, + "-q", "-b", "1", # UTF8 text encoding ] @@ -107,11 +107,20 @@ class ESpeak(BasePhonemizer): with '_'. This option requires espeak>=1.49. Default to False. 
""" # set arguments - args = ["-q", "-v", f"{self._language}"] + args = ["-v", f"{self._language}"] + # espeak and espeak-ng parses `ipa` differently if tie: - args.append("--ipa=1") # use '͡' between phonemes + # use '͡' between phonemes + if _DEF_ESPEAK_LIB == "espeak": + args.append("--ipa=1") + else: + args.append("--ipa=3") else: - args.append("--ipa=3") # split with '_' + # split with '_' + if _DEF_ESPEAK_LIB == "espeak": + args.append("--ipa=3") + else: + args.append("--ipa=1") if tie: args.append("--tie=%s" % tie) args.append(text) From 0fe39166feffe0537d6ad1d22374874ddb768bc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 25 Nov 2021 17:30:03 +0100 Subject: [PATCH 065/214] Discard OOV chars in tokenizer Discard but store OOV chars with a warninig message when the OOV char first recognized --- TTS/tts/utils/text/tokenizer.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 4163a0e2..3780426a 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -8,6 +8,8 @@ from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemize class TTSTokenizer: """🐸TTS tokenizer to convert input characters to token IDs and back. + Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later. + Args: use_phonemes (bool): Whether to use phonemes instead of characters. Defaults to False. @@ -45,14 +47,21 @@ class TTSTokenizer: self.add_blank = add_blank self.use_eos_bos = use_eos_bos self.characters = characters + self.not_found_characters = [] self.phonemizer = phonemizer def encode(self, text: str) -> List[int]: """Encodes a string of text as a sequence of IDs.""" token_ids = [] for char in text: - idx = self.characters.char_to_id(char) - token_ids.append(idx) + try: + idx = self.characters.char_to_id(char) + token_ids.append(idx) + except KeyError: + # discard but store not found characters + if char not in self.not_found_characters: + self.not_found_characters.append(char) + print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") return token_ids def decode(self, token_ids: List[int]) -> str: @@ -109,6 +118,10 @@ class TTSTokenizer: print(f"{indent}| > use_phonemes: {self.use_phonemes}") if self.use_phonemes: print(f"{indent}| > phonemizer: {self.phonemizer.print_logs(level + 1)}") + if len(self.not_found_characters) > 0: + print(f"{indent}| > {len(self.not_found_characters)} not found characters:") + for char in self.not_found_characters: + print(f"{indent}| > {char}") @staticmethod def init_from_config(config: "Coqpit"): From 4e8f9d6f101d7df4cfd08c9af626726f08325649 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 25 Nov 2021 17:30:54 +0100 Subject: [PATCH 066/214] Fix IPAPhonemes init_from_config --- TTS/tts/utils/text/characters.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index d2872c73..f9c44a7d 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -209,7 +209,7 @@ class BaseCharacters: ), f" [!] There are duplicate characters in the character set. 
{set([x for x in self.vocab if self.vocab.count(x) > 1])}" def char_to_id(self, char: str) -> int: - return self._char_to_id[char] + return self._char_to_id[char] def id_to_char(self, idx: int) -> str: return self._id_to_char[idx] @@ -283,12 +283,9 @@ class IPAPhonemes(BaseCharacters): @staticmethod def init_from_config(config: "Coqpit"): # band-aid for compatibility with old models - characters = None - if "characters" in config: - if "phonemes" in config.characters: + if "characters" in config and config.characters is not None: + if "phonemes" in config.characters and config.characters.phonemes is not None: config.characters["characters"] = config.characters["phonemes"] - # delattr(config.characters, "phonemes") - return IPAPhonemes( characters=config.characters["characters"], punctuations=config.characters["punctuations"], @@ -299,7 +296,10 @@ class IPAPhonemes(BaseCharacters): is_unique=config.characters["is_unique"], is_sorted=config.characters["is_sorted"], ) - return characters + else: + return IPAPhonemes( + **config.characters if config.characters is not None else {}, + ) class Graphemes(BaseCharacters): From 961e98a4610a434c750f3488e0f9896d60c24dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 25 Nov 2021 17:31:25 +0100 Subject: [PATCH 067/214] Add OOV case to tokenizer tests --- tests/text_tests/test_tokenizer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py index 8bee618b..6c48d276 100644 --- a/tests/text_tests/test_tokenizer.py +++ b/tests/text_tests/test_tokenizer.py @@ -52,6 +52,16 @@ class TestTTSTokenizer(unittest.TestCase): self.tokenizer.print_logs() self.tokenizer_ph.print_logs() + def test_not_found_characters(self): + self.ph = ESpeak("en-us") + self.tokenizer_local = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph) + self.assertEqual(len(self.tokenizer.not_found_characters), 0) + text = "Yolk of one egg beaten light" + ids = self.tokenizer_local.text_to_ids(text) + text_hat = self.tokenizer_local.ids_to_text(ids) + self.assertEqual(self.tokenizer_local.not_found_characters, ['̩']) + self.assertEqual(text_hat, "jˈoʊk ʌv wˈʌn ˈɛɡ bˈiːʔn lˈaɪt") + def test_init_from_config(self): @dataclass class Characters(Coqpit): From 4894998e6b67ded624259bbc7226248e7eab81ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 25 Nov 2021 17:56:42 +0100 Subject: [PATCH 068/214] Fix print_logs --- TTS/tts/utils/text/tokenizer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 3780426a..ada5e57b 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -113,11 +113,12 @@ class TTSTokenizer: def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > add_blank: {self.use_phonemes}") - print(f"{indent}| > use_eos_bos: {self.use_phonemes}") + print(f"{indent}| > add_blank: {self.add_blank}") + print(f"{indent}| > use_eos_bos: {self.use_eos_bos}") print(f"{indent}| > use_phonemes: {self.use_phonemes}") if self.use_phonemes: - print(f"{indent}| > phonemizer: {self.phonemizer.print_logs(level + 1)}") + print(f"{indent}| > phonemizer:") + self.phonemizer.print_logs(level + 1) if len(self.not_found_characters) > 0: print(f"{indent}| > {len(self.not_found_characters)} not found characters:") for char in self.not_found_characters: From 3b63d713b945449fcc443b5ac18d83517ef63bff Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:46:16 +0100 Subject: [PATCH 069/214] Fix espeak wrapper cmd call --- .../utils/text/phonemizers/espeak_wrapper.py | 28 +++++++++++-------- TTS/vocoder/models/gan.py | 2 +- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index f806f036..f1d0b6cd 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -29,7 +29,11 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: ] cmd.extend(args) logging.debug("espeakng: executing %s" % repr(cmd)) - p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) res = iter(p.stdout.readline, b"") if not sync: p.stdout.close() @@ -110,20 +114,20 @@ class ESpeak(BasePhonemizer): args = ["-v", f"{self._language}"] # espeak and espeak-ng parses `ipa` differently if tie: - # use '͡' between phonemes - if _DEF_ESPEAK_LIB == "espeak": - args.append("--ipa=1") - else: - args.append("--ipa=3") + # use '͡' between phonemes + if _DEF_ESPEAK_LIB == "espeak": + args.append("--ipa=1") + else: + args.append("--ipa=3") else: - # split with '_' - if _DEF_ESPEAK_LIB == "espeak": - args.append("--ipa=3") - else: - args.append("--ipa=1") + # split with '_' + if _DEF_ESPEAK_LIB == "espeak": + args.append("--ipa=3") + else: + args.append("--ipa=1") if tie: args.append("--tie=%s" % tie) - args.append(text) + args.append('"' + text + '"') # compute phonemes phonemes = "" for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index b4e3652e..e56d1db4 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -364,4 +364,4 @@ class GAN(BaseVocoder): @staticmethod def init_from_config(config: Coqpit) -> "GAN": - return GAN(config) \ No newline at end of file + return GAN(config) From 04202da1ac1dc6480e7c2d96045809d72d6734c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:48:47 +0100 Subject: [PATCH 070/214] Make style --- TTS/config/shared_configs.py | 2 +- TTS/model.py | 1 - TTS/tts/models/glow_tts.py | 5 ++++- TTS/tts/utils/text/tokenizer.py | 20 +++++++++++++++++--- tests/text_tests/test_tokenizer.py | 4 ++-- 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index f2bd40ad..217282ad 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -291,7 +291,7 @@ class BaseTrainingConfig(Coqpit): log_model_step (int): Number of steps required to log a checkpoint as W&B artifact - save_step (int):ipt + save_step (int): Number of steps required to save the next checkpoint. 
checkpoint (bool): diff --git a/TTS/model.py b/TTS/model.py index efa00b2a..6ce11e63 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -159,4 +159,3 @@ class BaseModel(nn.Module, ABC): def format_batch(self): pass - diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 3e5226aa..73680f32 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -522,14 +522,17 @@ class GlowTTS(BaseTTS): # init characters if config.use_phonemes: from TTS.tts.utils.text.characters import IPAPhonemes + characters = IPAPhonemes().init_from_config(config) else: from TTS.tts.utils.text.characters import Graphemes + characters = Graphemes().init_from_config(config) config.num_chars = characters.num_chars from TTS.utils.audio import AudioProcessor + ap = AudioProcessor.init_from_config(config) tokenizer = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config) - return GlowTTS(config, ap, tokenizer, speaker_manager) \ No newline at end of file + return GlowTTS(config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index ada5e57b..fac430f0 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -42,7 +42,7 @@ class TTSTokenizer: add_blank: bool = False, use_eos_bos=False, ): - self.text_cleaner = text_cleaner or (lambda x: x) + self.text_cleaner = text_cleaner self.use_phonemes = use_phonemes self.add_blank = add_blank self.use_eos_bos = use_eos_bos @@ -50,6 +50,16 @@ class TTSTokenizer: self.not_found_characters = [] self.phonemizer = phonemizer + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, new_characters): + self._characters = new_characters + self.pad_id = self.characters.char_to_id(self.characters.pad) + self.blank_id = self.characters.char_to_id(self.characters.blank) + def encode(self, text: str) -> List[int]: """Encodes a string of text as a sequence of IDs.""" token_ids = [] @@ -61,6 +71,7 @@ class TTSTokenizer: # discard but store not found characters if char not in self.not_found_characters: self.not_found_characters.append(char) + print(text) print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") return token_ids @@ -88,7 +99,8 @@ class TTSTokenizer: 5. 
Text to token IDs
        """
        # TODO: text cleaner should pick the right routine based on the language
-        text = self.text_cleaner(text)
+        if self.text_cleaner is not None:
+            text = self.text_cleaner(text)
         if self.use_phonemes:
             text = self.phonemizer.phonemize(text, separator="")
         if self.add_blank:
@@ -144,7 +156,9 @@ class TTSTokenizer:
             if "phonemizer" in config and config.phonemizer:
                 phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
             else:
-                phonemizer = get_phonemizer_by_name(DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs)
+                phonemizer = get_phonemizer_by_name(
+                    DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
+                )
         else:
             # init character set
             characters = Graphemes().init_from_config(config)
diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py
index 6c48d276..4d3fb0ce 100644
--- a/tests/text_tests/test_tokenizer.py
+++ b/tests/text_tests/test_tokenizer.py
@@ -56,10 +56,10 @@ class TestTTSTokenizer(unittest.TestCase):
         self.ph = ESpeak("en-us")
         self.tokenizer_local = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph)
         self.assertEqual(len(self.tokenizer.not_found_characters), 0)
-        text = "Yolk of one egg beaten light"
+        text = "Yolk of one egg beaten light"
         ids = self.tokenizer_local.text_to_ids(text)
         text_hat = self.tokenizer_local.ids_to_text(ids)
-        self.assertEqual(self.tokenizer_local.not_found_characters, ['̩'])
+        self.assertEqual(self.tokenizer_local.not_found_characters, ["̩"])
         self.assertEqual(text_hat, "jˈoʊk ʌv wˈʌn ˈɛɡ bˈiːʔn lˈaɪt")

From a71a013276c95e9ce2143b597da364b0dbefa165 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 30 Nov 2021 15:52:32 +0100
Subject: [PATCH 071/214] Fix the wrong default loss name for GAN models

---
 TTS/vocoder/configs/shared_configs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/vocoder/configs/shared_configs.py b/TTS/vocoder/configs/shared_configs.py
index 9ff6f790..a558cfca 100644
--- a/TTS/vocoder/configs/shared_configs.py
+++ b/TTS/vocoder/configs/shared_configs.py
@@ -99,7 +99,7 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
             "mel_fmax": None,
         }`
         target_loss (str):
-            Target loss name that defines the quality of the model. Defaults to `avg_G_loss`.
+            Target loss name that defines the quality of the model. Defaults to `G_avg_loss`.
         grad_clip (list):
             A list of gradient clipping theresholds for each optimizer. Any value less than 0 disables clipping.
             Defaults to [5, 5].

From 9397a56b1325f18ad4e017d1fe81b7293a820e36 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 30 Nov 2021 15:53:00 +0100
Subject: [PATCH 072/214] Allow init_from_config from model or audio config

---
 TTS/utils/audio.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index ee255f44..bdee8615 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -381,7 +381,10 @@ class AudioProcessor(object):

     @staticmethod
     def init_from_config(config: "Coqpit"):
-        return AudioProcessor(**config.audio)
+        if "audio" in config:
+            return AudioProcessor(**config.audio)
+        else:
+            return AudioProcessor(**config)

     ### setting up the parameters ###
     def _build_mel_basis(
@@ -729,6 +732,7 @@ class AudioProcessor(object):
             >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
             >>> pitch = ap.compute_f0(wav)
         """
+        assert self.mel_fmax is not None, " [!] Set `mel_fmax` before calling `compute_f0`."
         # align F0 length to the spectrogram length
         if len(x) % self.hop_length == 0:
             x = np.pad(x, (0, self.hop_length // 2), mode="reflect")

From 3476be30d78b9618bfa319366663165d1f1f9445 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:29:57 +0100
Subject: [PATCH 073/214] Refactor Synthesizer class for TTSTokenizer

---
 TTS/utils/synthesizer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 2e4f4735..a06a493f 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -114,7 +114,8 @@ class Synthesizer(object):
         self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes

-        self.tts_model = setup_tts_model(config=self.tts_config)
+        self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
+        self.tokenizer = TTSTokenizer.init_from_config(self.tts_config)

         speaker_manager = self._init_speaker_manager()
         language_manager = self._init_language_manager()
@@ -333,6 +333,8 @@ class Synthesizer(object):
                 text=sen,
                 CONFIG=self.tts_config,
                 use_cuda=self.use_cuda,
+                ap=self.ap,
+                tokenizer=self.tokenizer,
                 speaker_id=speaker_id,
                 language_id=language_id,
                 language_name=language_name,

From d0eb642d884058baca4b6b7fe2613114893c6ed7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:34:45 +0100
Subject: [PATCH 074/214] Refactor synthesis.py for TTSTokenizer

---
 TTS/tts/utils/synthesis.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index a4f4b0c8..7bbc282f 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -113,6 +113,8 @@ def synthesis(
     text,
     CONFIG,
     use_cuda,
+    ap,
+    tokenizer,
     speaker_id=None,
     style_wav=None,
     use_griffin_lim=False,
@@ -164,10 +166,10 @@ def synthesis(
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
-            style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+            style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # convert text to sequence of token IDs
     text_inputs = np.asarray(
-        model.tokenizer.text_to_ids(text),
+        tokenizer.text_to_ids(text),
         dtype=np.int32,
     )
     # pass tensors to backend

From 9a95e154839d5018cc5ab4adf3fb99916e926cdd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Nov 2021 13:36:35 +0100
Subject: [PATCH 075/214] Refactor GlowTTS model and recipe for TTSTokenizer

---
 TTS/tts/models/base_tts.py                 | 7 +++----
 recipes/ljspeech/glow_tts/train_glowtts.py | 7 ++++++-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 8bb7f02e..01f4a1de 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -14,6 +14,7 @@ from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.text.symbols import Graphemes, make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

 # pylint: skip-file
@@ -32,9 +33,7 @@ class BaseTTS(BaseModel):
         - 1D tensors `batch x 1`
     """

-    def __init__(
-        self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None
-    ):
+    def __init__(self, config: Coqpit, 
ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): super().__init__(config) self.config = config self.ap = ap @@ -292,7 +291,7 @@ class BaseTTS(BaseModel): verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, - tokenizer=self.tokenizer, + tokenizer=self.tokenizer ) # pre-compute phonemes diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index 4762a77a..fe4a9d9b 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -71,7 +71,12 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainingArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples ) # AND... 3,2,1... 🚀 From 2d8ce98d2a68a3f7d9df4b12393448b01a42a7c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 17 Nov 2021 12:46:04 +0100 Subject: [PATCH 076/214] Update imports for symbols -> characters --- TTS/tts/models/base_tts.py | 11 +++++------ recipes/ljspeech/glow_tts/train_glowtts.py | 7 +------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 01f4a1de..493c8869 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -14,7 +14,7 @@ from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.symbols import Graphemes, make_symbols +from TTS.tts.utils.text.characters import Graphemes, make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file @@ -33,7 +33,9 @@ class BaseTTS(BaseModel): - 1D tensors `batch x 1` """ - def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): + def __init__( + self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None + ): super().__init__(config) self.config = config self.ap = ap @@ -71,9 +73,6 @@ class BaseTTS(BaseModel): else: raise ValueError("config must be either a *Config or *Args") - def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: - return get_speaker_manager(config, restore_path, data, out_path) - def init_multispeaker(self, config: Coqpit, data: List = None): """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining `in_channels` size of the connected layers. 
@@ -291,7 +290,7 @@ class BaseTTS(BaseModel): verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, - tokenizer=self.tokenizer + tokenizer=self.tokenizer, ) # pre-compute phonemes diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index fe4a9d9b..4762a77a 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -71,12 +71,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀 From 1df1d6c4a9a2c9a9f4eb42442a98143d070b3adc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 17:49:20 +0100 Subject: [PATCH 077/214] Update for tokenizer API --- TTS/utils/synthesizer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index a06a493f..2e4f4735 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -114,8 +114,7 @@ class Synthesizer(object): self.tts_config = load_config(tts_config_path) self.use_phonemes = self.tts_config.use_phonemes - self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) - self.tokenizer = TTSTokenizer.init_from_config(self.tts_config) + self.tts_model = setup_tts_model(config=self.tts_config) speaker_manager = self._init_speaker_manager() language_manager = self._init_language_manager() @@ -333,8 +332,6 @@ class Synthesizer(object): text=sen, CONFIG=self.tts_config, use_cuda=self.use_cuda, - ap=self.ap, - tokenizer=self.tokenizer, speaker_id=speaker_id, language_id=language_id, language_name=language_name, From 4597d4e5b6b8c034a6e7e9d1f24be3f859eb3a9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:43:48 +0100 Subject: [PATCH 078/214] Remove get_characters from BaseTTS --- TTS/tts/models/base_tts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 493c8869..b9b4ed57 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -14,7 +14,6 @@ from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.characters import Graphemes, make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file From 176b712c1a40cf630da9a77f1826836723c40fde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:50:18 +0100 Subject: [PATCH 079/214] =?UTF-8?q?Refactor=20TTSDataset=20=E2=9A=A1?= =?UTF-8?q?=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/tts/datasets/dataset.py | 684 ++++++++++++++++++++++++------------ 1 file changed, 451 insertions(+), 233 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 8c21d7d0..60b514c2 100644 --- a/TTS/tts/datasets/dataset.py +++ 
b/TTS/tts/datasets/dataset.py @@ -2,7 +2,7 @@ import collections import os import random from multiprocessing import Pool -from typing import Dict, List +from typing import Dict, List, Union import numpy as np import torch @@ -14,6 +14,24 @@ from TTS.tts.utils.text import TTSTokenizer from TTS.utils.audio import AudioProcessor +def _parse_sample(item): + language_name = None + attn_file = None + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item + elif len(item) == 4: + text, wav_file, speaker_name, language_name = item + elif len(item) == 3: + text, wav_file, speaker_name = item + else: + raise ValueError(" [!] Dataset cannot parse the sample.") + return text, wav_file, speaker_name, language_name, attn_file + + +def noise_augment_audio(wav): + return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + + class TTSDataset(Dataset): def __init__( self, @@ -26,9 +44,12 @@ class TTSDataset(Dataset): f0_cache_path: str = None, return_wav: bool = False, batch_group_size: int = 0, - min_seq_len: int = 0, - max_seq_len: int = float("inf"), + min_text_len: int = 0, + max_text_len: int = float("inf"), + min_audio_len: int = 0, + max_audio_len: int = float("inf"), phoneme_cache_path: str = None, + precompute_num_workers: int = 0, speaker_id_mapping: Dict = None, d_vector_mapping: Dict = None, language_id_mapping: Dict = None, @@ -37,7 +58,7 @@ class TTSDataset(Dataset): ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. - If you need something different, you can inherit and override. + If you need something different, you can subclass and override. Args: outputs_per_step (int): Number of time frames predicted per step. @@ -61,17 +82,24 @@ class TTSDataset(Dataset): sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a batch. Set 0 to disable. Defaults to 0. - min_seq_len (int): Minimum input sequence length to be processed - by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a - minimum input length due to its architecture. Defaults to 0. + min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. + Defaults to 0. - max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this. - It helps for controlling the VRAM usage against long input sequences. Especially models with - RNN layers are sensitive to input length. Defaults to `Inf`. + max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored. + Defaults to float("inf"). + + min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored. + Defaults to 0. + + max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored. + The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to + this value if you encounter an OOM error in training. Defaults to float("inf"). phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a separate file. Defaults to None. + precompute_num_workers (int): Number of workers to precompute features. Defaults to 0. + speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the embedding layer. Defaults to None. 
@@ -83,15 +111,17 @@ class TTSDataset(Dataset): """ super().__init__() self.batch_group_size = batch_group_size - self.items = meta_data + self._samples = meta_data self.outputs_per_step = outputs_per_step self.sample_rate = ap.sample_rate self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav self.compute_f0 = compute_f0 self.f0_cache_path = f0_cache_path - self.min_seq_len = min_seq_len - self.max_seq_len = max_seq_len + self.min_audio_len = min_audio_len + self.max_audio_len = max_audio_len + self.min_text_len = min_text_len + self.max_text_len = max_text_len self.ap = ap self.phoneme_cache_path = phoneme_cache_path self.speaker_id_mapping = speaker_id_mapping @@ -100,112 +130,113 @@ class TTSDataset(Dataset): self.use_noise_augment = use_noise_augment self.verbose = verbose - self.input_seq_computed = False self.rescue_item_idx = 1 self.pitch_computed = False self.tokenizer = tokenizer - if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path): - os.makedirs(phoneme_cache_path, exist_ok=True) + self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples) + + if self.tokenizer.use_phonemes: + self.phoneme_dataset = PhonemeDataset( + self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers + ) + if compute_f0: - self.pitch_extractor = PitchExtractor(self.items, verbose=verbose) + self.f0_dataset = F0Dataset( + self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + ) if self.verbose: self.print_logs() + @property + def samples(self): + return self._samples + + @samples.setter + def samples(self, new_samples): + self._samples = new_samples + if hasattr(self, "f0_dataset"): + self.f0_dataset.samples = new_samples + if hasattr(self, "phoneme_dataset"): + self.phoneme_dataset.samples = new_samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.load_data(idx) + def print_logs(self, level: int = 0) -> None: indent = "\t" * level print("\n") print(f"{indent}> DataLoader initialization") print(f"{indent}| > Tokenizer:") self.tokenizer.print_logs(level + 1) - print(f"{indent}| > Number of instances : {len(self.items)}") + print(f"{indent}| > Number of instances : {len(self.samples)}") def load_wav(self, filename): - audio = self.ap.load_wav(filename) - return audio + waveform = self.ap.load_wav(filename) + assert waveform.size > 0 + return waveform - @staticmethod - def load_np(filename): - data = np.load(filename).astype("float32") - return data + def get_phonemes(self, idx, text): + out_dict = self.phoneme_dataset[idx] + assert text == out_dict["text"], f"{text} != {out_dict['text']}" + assert out_dict["token_ids"].size > 0 + return out_dict - @staticmethod - def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path): - """generate a phoneme sequence from text. - since the usage is for subsequent caching, we never add bos and - eos chars here. 
Instead we add those dynamically later; based on the - config option.""" - phonemes = tokenizer.text_to_ids(text) - phonemes = np.asarray(phonemes, dtype=np.int32) - np.save(cache_path, phonemes) - return phonemes + def get_f0(self, idx): + out_dict = self.f0_dataset[idx] + _, wav_file, *_ = _parse_sample(self.samples[idx]) + assert wav_file == out_dict["audio_file"] + return out_dict - @staticmethod - def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path): - file_name = os.path.splitext(os.path.basename(wav_file))[0] + def get_attn_mask(self, attn_file): + return np.load(attn_file) - # different names for normal phonemes and with blank chars. - file_name_ext = "_phoneme.npy" - cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext) - try: - phonemes = np.load(cache_path) - except FileNotFoundError: - phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path) - except (ValueError, IOError): - print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file)) - phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path) - phonemes = np.asarray(phonemes, dtype=np.int32) - return phonemes + def get_token_ids(self, idx, text): + if self.tokenizer.use_phonemes: + token_ids = self.get_phonemes(idx, text)["token_ids"] + else: + token_ids = self.tokenizer.text_to_ids(text) + return token_ids def load_data(self, idx): - item = self.items[idx] + item = self.samples[idx] + raw_text = item["text"] wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) # apply noise for augmentation if self.use_noise_augment: - wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + wav = noise_augment_audio(wav) - if not self.input_seq_computed: - if self.tokenizer.use_phonemes: - text = self._load_or_generate_phoneme_sequence( - item["audio_file"], - item["text"], - item["language"] if item["language"] else self.phoneme_language, - self.tokenizer, - self.phoneme_cache_path, - ) - else: - text = np.asarray( - self.tokenizer.text_to_ids(item["text"], item["language"]), - dtype=np.int32, - ) - - assert text.size > 0, self.items[idx]["audio_file"] - assert wav.size > 0, self.items[idx]["audio_file"] + # get token ids + token_ids = self.get_token_ids(idx, item["text"]) + # get pre-computed attention maps attn = None if "alignment_file" in item: - attn = np.load(item["alignment_file"]) + attn = self.get_attn_mask(item["alignment_file"]) - if len(text) > self.max_seq_len: - # return a different sample if the phonemized - # text is longer than the threshold - # TODO: find a better fix + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len: return self.load_data(self.rescue_item_idx) - pitch = None + # get f0 values + f0 = None if self.compute_f0: - pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path) - pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32)) + f0 = self.get_f0(idx)["f0"] sample = { "raw_text": raw_text, - "text": text, + "token_ids": token_ids, "wav": wav, - "pitch": pitch, + "pitch": f0, "attn": attn, "item_idx": item["audio_file"], "speaker_name": item["speaker_name"], @@ -215,105 +246,78 @@ class TTSDataset(Dataset): return sample @staticmethod - def _phoneme_worker(args): - item = args[0] - func_args = args[1] -
func_args[3] = ( - item["language"] if "language" in item and item["language"] else func_args[3] - ) # override phoneme language if specified by the dataset formatter - phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args) - return phonemes + def compute_lengths(samples): + audio_lengths = [] + text_lengths = [] + for item in samples: + text, wav_file, *_ = _parse_sample(item) + audio_lengths.append(os.path.getsize(wav_file) / 16 * 8) # assuming 16bit audio + text_lengths.append(len(text)) + audio_lengths = np.array(audio_lengths) + text_lengths = np.array(text_lengths) + return audio_lengths, text_lengths - def compute_input_seq(self, num_workers=0): - """Compute the input sequences with multi-processing. - Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" - if not self.use_phonemes: - if self.verbose: - print(" | > Computing input sequences ...") - for idx, item in enumerate(tqdm.tqdm(self.items)): - sequence = np.asarray( - self.tokenizer.text_to_ids(item["text"]), - dtype=np.int32, - ) - self.items[idx][0] = sequence - else: - func_args = [ - self.phoneme_cache_path, - self.enable_eos_bos, - self.cleaners, - self.phoneme_language, - self.characters, - self.add_blank, - ] - if self.verbose: - print(" | > Computing phonemes ...") - if num_workers == 0: - for idx, item in enumerate(tqdm.tqdm(self.items)): - phonemes = self._phoneme_worker([item, func_args]) - self.items[idx][0] = phonemes - else: - with Pool(num_workers) as p: - phonemes = list( - tqdm.tqdm( - p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]), - total=len(self.items), - ) - ) - for idx, p in enumerate(phonemes): - self.items[idx][0] = p - - def sort_and_filter_items(self, by_audio_len=False): - r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length - range. - - Args: - by_audio_len (bool): if True, sort by audio length else by text length. - """ - # compute the target sequence length - if by_audio_len: - lengths = [] - for item in self.items: - lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8) # assuming 16bit audio - lengths = np.array(lengths) - else: - lengths = np.array([len(ins["text"]) for ins in self.items]) - - idxs = np.argsort(lengths) - new_items = [] - ignored = [] + @staticmethod + def sort_and_filter_by_length(lengths:List[int], min_len:int, max_len:int): + idxs = np.argsort(lengths) # ascending order + ignore_idx = [] + keep_idx = [] for i, idx in enumerate(idxs): length = lengths[idx] - if length < self.min_seq_len or length > self.max_seq_len: - ignored.append(idx) + if length < min_len or length > max_len: + ignore_idx.append(idx) else: - new_items.append(self.items[idx]) + keep_idx.append(idx) + return ignore_idx, keep_idx + + @staticmethod + def create_buckets(samples, batch_group_size:int): + for i in range(len(samples) // batch_group_size): + offset = i * batch_group_size + end_offset = offset + batch_group_size + temp_items = samples[offset:end_offset] + random.shuffle(temp_items) + samples[offset:end_offset] = temp_items + return samples + + def preprocess_samples(self): + r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out of the length + range.
+ """ + + # sort items based on the sequence length in ascending order + text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len) + audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(self.audio_lengths, self.min_audio_len, self.max_audio_len) + keep_idx = list(set(audio_keep_idx) | set(text_keep_idx)) + ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) + + samples = [] + for idx in keep_idx: + samples.append(self.samples[idx]) + + if len(samples) == 0: + raise RuntimeError(" [!] No samples left") + # shuffle batch groups - if self.batch_group_size > 0: - for i in range(len(new_items) // self.batch_group_size): - offset = i * self.batch_group_size - end_offset = offset + self.batch_group_size - temp_items = new_items[offset:end_offset] - random.shuffle(temp_items) - new_items[offset:end_offset] = temp_items - self.items = new_items + # create batches with similar length items + # the larger the `batch_group_size`, the higher the length variety in a batch. + samples = self.create_buckets(samples, self.batch_group_size) + + # update items to the new sorted items + self.samples = samples if self.verbose: - print(" | > Max length sequence: {}".format(np.max(lengths))) - print(" | > Min length sequence: {}".format(np.min(lengths))) - print(" | > Avg length sequence: {}".format(np.mean(lengths))) - print( - " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format( - self.max_seq_len, self.min_seq_len, len(ignored) - ) - ) + print(" | > Preprocessing samples") + print(" | > Max text length: {}".format(np.max(self.text_lengths))) + print(" | > Min text length: {}".format(np.min(self.text_lengths))) + print(" | > Avg text length: {}".format(np.mean(self.text_lengths))) + print(" | ") + print(" | > Max audio length: {}".format(np.max(self.audio_lengths))) + print(" | > Min audio length: {}".format(np.min(self.audio_lengths))) + print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths))) + print(f" | > Num. instances discarded samples: {len(ignore_idx)}") print(" | > Batch group size: {}.".format(self.batch_group_size)) - def __len__(self): - return len(self.items) - - def __getitem__(self, idx): - return self.load_data(idx) - @staticmethod def _sort_batch(batch, text_lengths): """Sort the batch by the input text length for RNN efficiency. 
@@ -338,10 +342,10 @@ class TTSDataset(Dataset): # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.abc.Mapping): - text_lengths = np.array([len(d["text"]) for d in batch]) + token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) # sort items with text input length for RNN efficiency - batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths) + batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths) # convert list of dicts to dict of lists batch = {k: [dic[k] for dic in batch] for k in batch[0]} @@ -383,7 +387,7 @@ class TTSDataset(Dataset): stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch - text = prepare_data(batch["text"]).astype(np.int32) + text = prepare_data(batch["token_ids"]).astype(np.int32) # PAD features with longest instance mel = prepare_tensor(mel, self.outputs_per_step) @@ -392,12 +396,13 @@ class TTSDataset(Dataset): mel = mel.transpose(0, 2, 1) # convert things to pytorch - text_lengths = torch.LongTensor(text_lengths) + token_ids_lengths = torch.LongTensor(token_ids_lengths) text = torch.LongTensor(text) mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) + # speaker vectors if d_vectors is not None: d_vectors = torch.FloatTensor(d_vectors) @@ -408,14 +413,13 @@ class TTSDataset(Dataset): language_ids = torch.LongTensor(language_ids) # compute linear spectrogram + linear = None if self.compute_linear_spec: linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] linear = prepare_tensor(linear, self.outputs_per_step) linear = linear.transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] linear = torch.FloatTensor(linear).contiguous() - else: - linear = None # format waveforms wav_padded = None @@ -431,8 +435,7 @@ class TTSDataset(Dataset): wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) wav_padded.transpose_(1, 2) - # compute f0 - # TODO: compare perf in collate_fn vs in load_data + # format F0 if self.compute_f0: pitch = prepare_data(batch["pitch"]) assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" @@ -440,7 +443,8 @@ class TTSDataset(Dataset): else: pitch = None - # collate attention alignments + # format attention masks + attns = None if batch["attn"][0] is not None: attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): @@ -451,12 +455,10 @@ class TTSDataset(Dataset): attns[idx] = attn attns = prepare_tensor(attns, self.outputs_per_step) attns = torch.FloatTensor(attns).unsqueeze(1) - else: - attns = None - # TODO: return dictionary + return { - "text": text, - "text_lengths": text_lengths, + "token_id": text, + "token_id_lengths": token_ids_lengths, "speaker_names": batch["speaker_name"], "linear": linear, "mel": mel, @@ -482,22 +484,179 @@ class TTSDataset(Dataset): ) -class PitchExtractor: - """Pitch Extractor for computing F0 from wav files. +class PhonemeDataset(Dataset): + """Phoneme Dataset for converting input text to phonemes and then token IDs + + At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data + loading latency. If `cache_path` is already present, it skips the pre-computation. + Args: - items (List[List]): Dataset samples. - verbose (bool): Whether to print the progress. 
+ samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + tokenizer (TTSTokenizer): + Tokenizer to convert input text to phonemes. + + cache_path (str): + Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation. + + precompute_num_workers (int): + Number of workers used for pre-computing the phonemes. Defaults to 0. """ def __init__( self, - items: List[Dict], - verbose=False, + samples: Union[List[Dict], List[List]], + tokenizer: "TTSTokenizer", + cache_path: str, + precompute_num_workers=0, ): - self.items = items + self.samples = samples + self.tokenizer = tokenizer + self.cache_path = cache_path + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + + def __getitem__(self, index): + text, wav_file, *_ = _parse_sample(self.samples[index]) + ids = self.compute_or_load(wav_file, text) + ph_hat = self.tokenizer.ids_to_text(ids) + return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} + + def __len__(self): + return len(self.samples) + + def compute_or_load(self, wav_file, text): + """Compute phonemes for the given text. + + If the phonemes are already cached, load them from cache. + """ + file_name = os.path.splitext(os.path.basename(wav_file))[0] + file_ext = "_phoneme.npy" + cache_path = os.path.join(self.cache_path, file_name + file_ext) + try: + ids = np.load(cache_path) + except FileNotFoundError: + ids = self.tokenizer.text_to_ids(text) + np.save(cache_path, ids) + return ids + + def get_pad_id(self): + """Get pad token ID for sequence padding""" + return self.tokenizer.pad_id + + def precompute(self, num_workers=1): + """Precompute phonemes for all samples. + + We use a PyTorch DataLoader here for convenient batching and multi-processing. + """ + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + dataloader = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + for _ in dataloader: + pbar.update(batch_size) + + def collate_fn(self, batch): + ids = [item["token_ids"] for item in batch] + ids_lens = [item["token_ids_len"] for item in batch] + texts = [item["text"] for item in batch] + texts_hat = [item["ph_hat"] for item in batch] + ids_lens_max = max(ids_lens) + ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id()) + for i, ids_len in enumerate(ids_lens): + ids_torch[i, :ids_len] = torch.LongTensor(ids[i]) + return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> PhonemeDataset ") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") + + +class F0Dataset: + """F0 Dataset for computing F0 from wav files on CPU + + Pre-computes F0 values for all the samples at initialization if `cache_path` is set and does not exist yet. It + also computes the mean and std of F0 values if `normalize_f0` is True. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute F0 from wav files. + + cache_path (str): + Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None.
+ + precompute_num_workers (int): + Number of workers used for pre-computing the F0 values. Defaults to 0. + + normalize_f0 (bool): + Whether to normalize F0 values by mean and std. Defaults to True. + """ + + def __init__( + self, + samples: Union[List[List], List[Dict]], + ap: "AudioProcessor", + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + self.samples = samples + self.ap = ap self.verbose = verbose + self.cache_path = cache_path + self.normalize_f0 = normalize_f0 + self.pad_id = 0.0 self.mean = None self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_f0: + self.load_stats(cache_path) + + def __getitem__(self, idx): + _, wav_file, *_ = _parse_sample(self.samples[idx]) + f0 = self.compute_or_load(wav_file) + if self.normalize_f0: + assert self.mean is not None and self.std is not None, " [!] Mean and STD are not available" + f0 = self.normalize(f0) + return {"audio_file": wav_file, "f0": f0} + + def __len__(self): + return len(self.samples) + + def precompute(self, num_workers=0): + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + dataloader = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + computed_data = [] + for batch in dataloader: + f0 = batch["f0"] + computed_data.append([f for f in f0]) + pbar.update(batch_size) + + if self.normalize_f0: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + pitch_mean, pitch_std = self.compute_pitch_stats(computed_data) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id @staticmethod def create_pitch_file_path(wav_file, cache_path): @@ -519,69 +678,128 @@ class PitchExtractor: mean, std = np.mean(nonzeros), np.std(nonzeros) return mean, std - def normalize_pitch(self, pitch): + def load_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, pitch): zero_idxs = np.where(pitch == 0.0)[0] pitch = pitch - self.mean pitch = pitch / self.std pitch[zero_idxs] = 0.0 return pitch - def denormalize_pitch(self, pitch): + def denormalize(self, pitch): zero_idxs = np.where(pitch == 0.0)[0] pitch *= self.std pitch += self.mean pitch[zero_idxs] = 0.0 return pitch - @staticmethod - def load_or_compute_pitch(ap, wav_file, cache_path): + def compute_or_load(self, wav_file): """ compute pitch and return a numpy array of pitch values """ - pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + pitch_file = self.create_pitch_file_path(wav_file, self.cache_path) if not os.path.exists(pitch_file): - pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) else: pitch = np.load(pitch_file) return pitch.astype(np.float32) - @staticmethod - def _pitch_worker(args): - item = args[0] - ap = args[1] - cache_path = args[2] - pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path) - if not os.path.exists(pitch_file): - pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"],
pitch_file) - return pitch - return None + def collate_fn(self, batch): + audio_file = [item["audio_file"] for item in batch] + f0s = [item["f0"] for item in batch] + f0_lens = [len(item["f0"]) for item in batch] + f0_lens_max = max(f0_lens) + f0s_torch = torch.FloatTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id()) + for i, f0_len in enumerate(f0_lens): + f0s_torch[i, :f0_len] = torch.FloatTensor(f0s[i]) + return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens} - def compute_pitch(self, ap, cache_path, num_workers=0): - """Compute the input sequences with multi-processing. - Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" - if not os.path.exists(cache_path): - os.makedirs(cache_path, exist_ok=True) + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> F0Dataset ") + print(f"{indent}| > Number of instances : {len(self.samples)}") - if self.verbose: - print(" | > Computing pitch features ...") - if num_workers == 0: - pitch_vecs = [] - for _, item in enumerate(tqdm.tqdm(self.items)): - pitch_vecs += [self._pitch_worker([item, ap, cache_path])] - else: - with Pool(num_workers) as p: - pitch_vecs = list( - tqdm.tqdm( - p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]), - total=len(self.items), - ) - ) - pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) - pitch_stats = {"mean": pitch_mean, "std": pitch_std} - np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) - def load_pitch_stats(self, cache_path): - stats_path = os.path.join(cache_path, "pitch_stats.npy") - stats = np.load(stats_path, allow_pickle=True).item() - self.mean = stats["mean"].astype(np.float32) - self.std = stats["std"].astype(np.float32) +if __name__ == "__main__": + from torch.utils.data import DataLoader + + from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig + from TTS.tts.datasets import load_tts_samples + from TTS.tts.utils.text.characters import IPAPhonemes + from TTS.tts.utils.text.phonemizers import ESpeak + + dataset_config = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1", + ) + train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) + samples = train_samples + eval_samples + + phonemizer = ESpeak(language="en-us") + tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer) + # ph_dataset = PhonemeDataset(samples, tokenizer, phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests") + # ph_dataset.precompute(num_workers=4) + + # dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn) + # for batch in dataloader: + # print(batch) + # break + + audio_config = BaseAudioConfig( + sample_rate=22050, + win_length=1024, + hop_length=256, + num_mels=80, + preemphasis=0.0, + ref_level_db=20, + log_func="np.log", + do_trim_silence=True, + trim_db=45, + mel_fmin=0, + mel_fmax=8000, + spec_gain=1.0, + signal_norm=False, + do_amp_to_db_linear=False, + ) + + ap = AudioProcessor.init_from_config(audio_config) + + # f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4) + + # dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn) + # for batch in dataloader: + #
print(batch) + # breakpoint() + # break + + dataset = TTSDataset( + outputs_per_step=1, + compute_linear_spec=False, + meta_data=samples, + ap=ap, + return_wav=False, + batch_group_size=0, + min_seq_len=0, + max_seq_len=500, + use_noise_augment=False, + verbose=True, + speaker_id_mapping=None, + d_vector_mapping=None, + compute_f0=True, + f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests", + tokenizer=tokenizer, + phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests", + precompute_num_workers=4, + ) + + dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn) + for batch in dataloader: + print(batch) + break From 4cd690e4c17802195d5e2d5340278a1bf9bd2af9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:52:01 +0100 Subject: [PATCH 080/214] Updates BaseTTS and configs --- TTS/tts/configs/shared_configs.py | 22 +++++++++---- TTS/tts/models/base_tts.py | 51 +++++-------------------------- 2 files changed, 24 insertions(+), 49 deletions(-) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index b101b70a..98461bdd 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -146,11 +146,19 @@ class BaseTTSConfig(BaseTrainingConfig): sort_by_audio_len (bool): If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`. - min_seq_len (int): - Minimum sequence length to be used at training. + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. - max_seq_len (int): - Maximum sequence length to be used at training. Larger values result in more VRAM usage. + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). compute_f0 (int): (Not in use yet). 
@@ -211,8 +219,10 @@ class BaseTTSConfig(BaseTrainingConfig): loss_masking: bool = None # dataloading sort_by_audio_len: bool = False - min_seq_len: int = 1 - max_seq_len: int = float("inf") + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") compute_f0: bool = False compute_linear_spec: bool = False use_noise_augment: bool = False diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index b9b4ed57..27231790 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -168,8 +168,8 @@ class BaseTTS(BaseModel): Dict: [description] """ # setup input batch - text_input = batch["text"] - text_lengths = batch["text_lengths"] + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] speaker_names = batch["speaker_names"] linear_input = batch["linear"] mel_input = batch["mel"] @@ -261,10 +261,6 @@ class BaseTTS(BaseModel): d_vector_mapping = None # setup custom symbols if needed - custom_symbols = None - if hasattr(self, "make_symbols"): - custom_symbols = self.make_symbols(self.config) - if hasattr(self, "language_manager"): language_id_mapping = ( self.language_manager.language_id_mapping if self.args.use_language_embedding else None @@ -282,8 +278,10 @@ class BaseTTS(BaseModel): ap=self.ap, return_wav=config.return_wav if "return_wav" in config else False, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, - min_seq_len=config.min_seq_len, - max_seq_len=config.max_seq_len, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, use_noise_augment=False if is_eval else config.use_noise_augment, verbose=verbose, @@ -292,45 +290,12 @@ class BaseTTS(BaseModel): tokenizer=self.tokenizer, ) - # pre-compute phonemes - if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]: - if hasattr(self, "eval_data_items") and is_eval: - dataset.items = self.eval_data_items - elif hasattr(self, "train_data_items") and not is_eval: - dataset.items = self.train_data_items - else: - # precompute phonemes for precise estimate of sequence lengths. - # otherwise `dataset.sort_items()` uses raw text lengths - dataset.compute_input_seq(config.num_loader_workers) - - # TODO: find a more efficient solution - # cheap hack - store items in the model state to avoid recomputing when reinit the dataset - if is_eval: - self.eval_data_items = dataset.items - else: - self.train_data_items = dataset.items - - # halt DDP processes for the main process to finish computing the phoneme cache + # wait for all the DDP processes to be ready if num_gpus > 1: dist.barrier() # sort input sequences from short to long - dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) - - # compute pitch frames and write to files.
- if config.compute_f0 and rank in [None, 0]: - if not os.path.exists(config.f0_cache_path): - dataset.pitch_extractor.compute_pitch( - self.ap, config.get("f0_cache_path", None), config.num_loader_workers - ) - - # halt DDP processes for the main process to finish computing the F0 cache - if num_gpus > 1: - dist.barrier() - - # load pitch stats computed above by all the workers - if config.compute_f0: - dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + dataset.preprocess_samples() # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None From 7575367b9f386639acd56af630fda2b9bb8b5136 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:55:36 +0100 Subject: [PATCH 081/214] Refactorin VITS for the tokenizer API --- TTS/tts/models/vits.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index ae24a99e..23be6177 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -19,6 +19,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.models.hifigan_generator import HifiganGenerator @@ -280,19 +281,15 @@ class Vits(BaseTTS): language_manager: LanguageManager = None, ): - super().__init__(config) + super().__init__(config, ap, tokenizer, speaker_manager) self.END2END = True self.speaker_manager = speaker_manager self.language_manager = language_manager if config.__class__.__name__ == "VitsConfig": # loading from VitsConfig - if "num_chars" not in config: - _, self.config, num_chars = self.get_characters(config) - config.model_args.num_chars = num_chars - else: - self.config = config - config.model_args.num_chars = config.num_chars + self.num_chars = self.tokenizer.characters.num_chars + self.config = config args = self.config.model_args elif isinstance(config, VitsArgs): # loading from VitsArgs @@ -1039,3 +1036,25 @@ class Vits(BaseTTS): if eval: self.eval() assert not self.training + + @staticmethod + def init_from_config(config: "Coqpit"): + """Initialize model from config.""" + + # init characters + if config.use_phonemes: + from TTS.tts.utils.text.characters import IPAPhonemes + + characters = IPAPhonemes().init_from_config(config) + else: + from TTS.tts.utils.text.characters import Graphemes + + characters = Graphemes().init_from_config(config) + config.num_chars = characters.num_chars + + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config) + return Vits(config, ap, tokenizer, speaker_manager) From 98057a00ae248e629671ccc34d4acd46cb3fd8ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:56:15 +0100 Subject: [PATCH 082/214] Make style --- TTS/tts/datasets/dataset.py | 12 ++++++++---- TTS/tts/utils/text/characters.py | 6 +++--- TTS/tts/utils/text/phonemizers/__init__.py | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 60b514c2..337dcfa5 100644 --- 
a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -258,7 +258,7 @@ class TTSDataset(Dataset): return audio_lengths, text_lengths @staticmethod - def sort_and_filter_by_length(lengths:List[int], min_len:int, max_len:int): + def sort_and_filter_by_length(lengths: List[int], min_len: int, max_len: int): idxs = np.argsort(lengths) # ascending order ignore_idx = [] keep_idx = [] @@ -271,7 +271,7 @@ class TTSDataset(Dataset): return ignore_idx, keep_idx @staticmethod - def create_buckets(samples, batch_group_size:int): + def create_buckets(samples, batch_group_size: int): for i in range(len(samples) // batch_group_size): offset = i * batch_group_size end_offset = offset + batch_group_size @@ -286,8 +286,12 @@ class TTSDataset(Dataset): """ # sort items based on the sequence length in ascending order - text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len) - audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(self.audio_lengths, self.min_audio_len, self.max_audio_len) + text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length( + self.text_lengths, self.min_text_len, self.max_text_len + ) + audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length( + self.audio_lengths, self.min_audio_len, self.max_audio_len + ) keep_idx = list(set(audio_keep_idx) | set(text_keep_idx)) ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index f9c44a7d..24ce51f1 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -209,7 +209,7 @@ class BaseCharacters: ), f" [!] There are duplicate characters in the character set. {set([x for x in self.vocab if self.vocab.count(x) > 1])}" def char_to_id(self, char: str) -> int: - return self._char_to_id[char] + return self._char_to_id[char] def id_to_char(self, idx: int) -> str: return self._id_to_char[idx] @@ -298,8 +298,8 @@ class IPAPhonemes(BaseCharacters): ) else: return IPAPhonemes( - **config.characters if config.characters is not None else {}, - ) + **config.characters if config.characters is not None else {}, + ) class Graphemes(BaseCharacters): diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py index b00f7f5e..0da5875e 100644 --- a/TTS/tts/utils/text/phonemizers/__init__.py +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -52,4 +52,4 @@ def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: if __name__ == "__main__": - print(DEF_LANG_TO_PHONEMIZER) \ No newline at end of file + print(DEF_LANG_TO_PHONEMIZER) From 75c507c36a923f61a9c8e3869c7be5c96bd94773 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:57:12 +0100 Subject: [PATCH 083/214] Update VITS LJspeech recipe --- recipes/ljspeech/vits_tts/train_vits.py | 22 ++++++++++++++----- .../test_vocoder_multiband_melgan_config.json | 2 +- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index e86cc861..0588e9d9 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -6,6 +6,7 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.vits import Vits +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import 
AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -35,7 +36,7 @@ config = VitsConfig( batch_size=48, eval_batch_size=16, batch_group_size=5, - num_loader_workers=4, + num_loader_workers=0, num_eval_loader_workers=4, run_eval=True, test_delay_epochs=-1, @@ -53,14 +54,24 @@ config = VitsConfig( datasets=[dataset_config], ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +tokenizer = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init model -model = Vits(config) +model = Vits(config, ap, tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer( @@ -70,6 +81,5 @@ trainer = Trainer( model=model, train_samples=train_samples, eval_samples=eval_samples, - training_assets={"audio_processor": ap}, ) trainer.fit() diff --git a/tests/inputs/test_vocoder_multiband_melgan_config.json b/tests/inputs/test_vocoder_multiband_melgan_config.json index b8b192e4..82afc977 100644 --- a/tests/inputs/test_vocoder_multiband_melgan_config.json +++ b/tests/inputs/test_vocoder_multiband_melgan_config.json @@ -86,7 +86,7 @@ "mel_fmax": null }, - "target_loss": "avg_G_loss", // loss value to pick the best model to save after each epoch + "target_loss": "G_avg_loss", // loss value to pick the best model to save after each epoch // DISCRIMINATOR "discriminator_model": "melgan_multiscale_discriminator", From 196ae74273450beaa61633c030edc9587778e41c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 1 Dec 2021 10:06:02 +0100 Subject: [PATCH 084/214] Update data loader tests --- TTS/tts/datasets/dataset.py | 50 +++++++---- TTS/tts/utils/text/tokenizer.py | 1 + tests/data_tests/test_loader.py | 143 ++++++++++++++------------------ 3 files changed, 97 insertions(+), 97 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 337dcfa5..d4a12c07 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -38,7 +38,7 @@ class TTSDataset(Dataset): outputs_per_step: int, compute_linear_spec: bool, ap: AudioProcessor, - meta_data: List[Dict], + samples: List[Dict], tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, f0_cache_path: str = None, @@ -67,7 +67,7 @@ class TTSDataset(Dataset): ap (TTS.tts.utils.AudioProcessor): Audio processor object. - meta_data (list): List of dataset samples. + samples (list): List of dataset samples. tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else use the given. Defaults to None. 
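With `meta_data` renamed to `samples`, constructing the loader directly looks roughly as below; this sketch reuses the `ap`, `tokenizer`, and `train_samples` objects prepared as in the recipe above, and the length limits are placeholders:

from TTS.tts.datasets import TTSDataset

dataset = TTSDataset(
    outputs_per_step=1,
    compute_linear_spec=False,
    ap=ap,
    samples=train_samples,
    tokenizer=tokenizer,
    min_text_len=1,
    max_text_len=500,
)
dataset.preprocess_samples()  # filter by length, sort by audio length, then bucket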
@@ -111,7 +111,7 @@ class TTSDataset(Dataset): """ super().__init__() self.batch_group_size = batch_group_size - self._samples = meta_data + self._samples = samples self.outputs_per_step = outputs_per_step self.sample_rate = ap.sample_rate self.compute_linear_spec = compute_linear_spec @@ -200,7 +200,7 @@ class TTSDataset(Dataset): token_ids = self.get_phonemes(idx, text)["token_ids"] else: token_ids = self.tokenizer.text_to_ids(text) - return token_ids + return np.array(token_ids, dtype=np.int32) def load_data(self, idx): item = self.samples[idx] @@ -258,7 +258,7 @@ class TTSDataset(Dataset): return audio_lengths, text_lengths @staticmethod - def sort_and_filter_by_length(lengths: List[int], min_len: int, max_len: int): + def filter_by_length(lengths: List[int], min_len: int, max_len: int): idxs = np.argsort(lengths) # ascending order ignore_idx = [] keep_idx = [] @@ -270,6 +270,11 @@ class TTSDataset(Dataset): keep_idx.append(idx) return ignore_idx, keep_idx + @staticmethod + def sort_by_length(lengths: List[int]): + idxs = np.argsort(lengths) # ascending order + return idxs + @staticmethod def create_buckets(samples, batch_group_size: int): for i in range(len(samples) // batch_group_size): @@ -280,24 +285,33 @@ class TTSDataset(Dataset): samples[offset:end_offset] = temp_items return samples + def select_samples_by_idx(self, idxs): + samples = [] + audio_lengths = [] + text_lengths = [] + for idx in idxs: + samples.append(self.samples[idx]) + audio_lengths.append(self.audio_lengths[idx]) + text_lengths.append(self.text_lengths[idx]) + return samples, audio_lengths, text_lengths + def preprocess_samples(self): r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out of the length range. """ # sort items based on the sequence length in ascending order - text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length( - self.text_lengths, self.min_text_len, self.max_text_len - ) - audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length( + text_ignore_idx, text_keep_idx = self.filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len) + audio_ignore_idx, audio_keep_idx = self.filter_by_length( self.audio_lengths, self.min_audio_len, self.max_audio_len ) keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) - samples = [] - for idx in keep_idx: - samples.append(self.samples[idx]) + samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx) + + sorted_idxs = self.sort_by_length(audio_lengths) + samples, audio_lengths, text_lengths = self.select_samples_by_idx(sorted_idxs) if len(samples) == 0: raise RuntimeError(" [!]
No samples left") @@ -309,6 +323,8 @@ class TTSDataset(Dataset): # update items to the new sorted items self.samples = samples + self.audio_lengths = audio_lengths + self.text_lengths = text_lengtsh if self.verbose: print(" | > Preprocessing samples") @@ -391,7 +407,7 @@ class TTSDataset(Dataset): stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch - text = prepare_data(batch["token_ids"]).astype(np.int32) + token_ids = prepare_data(batch["token_ids"]).astype(np.int32) # PAD features with longest instance mel = prepare_tensor(mel, self.outputs_per_step) @@ -401,7 +417,7 @@ class TTSDataset(Dataset): # convert things to pytorch token_ids_lengths = torch.LongTensor(token_ids_lengths) - text = torch.LongTensor(text) + token_ids = torch.LongTensor(token_ids) mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) @@ -453,7 +469,7 @@ class TTSDataset(Dataset): attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): pad2 = mel.shape[1] - attn.shape[1] - pad1 = text.shape[1] - attn.shape[0] + pad1 = token_ids.shape[1] - attn.shape[0] assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" attn = np.pad(attn, [[0, pad1], [0, pad2]]) attns[idx] = attn @@ -461,7 +477,7 @@ class TTSDataset(Dataset): attns = torch.FloatTensor(attns).unsqueeze(1) return { - "token_id": text, + "token_id": token_ids, "token_id_lengths": token_ids_lengths, "speaker_names": batch["speaker_name"], "linear": linear, @@ -786,7 +802,7 @@ if __name__ == "__main__": dataset = TTSDataset( outputs_per_step=1, compute_linear_spec=False, - meta_data=samples, + samples=samples, ap=ap, return_wav=False, batch_group_size=0, diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index fac430f0..68a1c575 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -147,6 +147,7 @@ class TTSTokenizer: if isinstance(config.text_cleaner, (str, list)): text_cleaner = getattr(cleaners, config.text_cleaner) + phonemizer = None if config.use_phonemes: # init phoneme set characters = IPAPhonemes().init_from_config(config) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index d210995d..712e59e3 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -7,9 +7,9 @@ import torch from torch.utils.data import DataLoader from tests import get_tests_output_path -from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.configs.shared_configs import BaseTTSConfig, BaseDatasetConfig from TTS.tts.datasets import TTSDataset, load_tts_samples -from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # pylint: disable=unused-variable @@ -50,18 +50,19 @@ class TestTTSDataset(unittest.TestCase): meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2) items = meta_data_train + meta_data_eval + tokenizer = TTSTokenizer.init_from_config(c) dataset = TTSDataset( - r, - c.text_cleaner, + outputs_per_step=r, compute_linear_spec=True, return_wav=True, + tokenizer=tokenizer, ap=self.ap, - meta_data=items, - characters=c.characters, + samples=items, batch_group_size=bgs, - min_seq_len=c.min_seq_len, - max_seq_len=float("inf"), - use_phonemes=False, + min_text_len=c.min_text_len, + 
max_text_len=c.max_text_len, + min_audio_len=c.min_audio_len, + max_audio_len=c.max_audio_len, ) dataloader = DataLoader( dataset, @@ -80,27 +81,26 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data["text"] - text_lengths = data["text_lengths"] + text_input = data["token_id"] + _ = data["token_id_lengths"] speaker_name = data["speaker_names"] linear_input = data["linear"] mel_input = data["mel"] mel_lengths = data["mel_lengths"] - stop_target = data["stop_targets"] - item_idx = data["item_idxs"] + _ = data["stop_targets"] + _ = data["item_idxs"] wavs = data["waveform"] neg_values = text_input[text_input < 0] check_count = len(neg_values) - assert check_count == 0, " !! Negative values in text_input: {}".format(check_count) - assert isinstance(speaker_name[0], str) - assert linear_input.shape[0] == c.batch_size - assert linear_input.shape[2] == self.ap.fft_size // 2 + 1 - assert mel_input.shape[0] == c.batch_size - assert mel_input.shape[2] == c.audio["num_mels"] - assert ( - wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length - ), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}" + + # check basic conditions + self.assertEqual(check_count, 0) + self.assertEqual(linear_input.shape[0], mel_input.shape[0], c.batch_size) + self.assertEqual(linear_input.shape[2], self.ap.fft_size // 2 + 1) + self.assertEqual(mel_input.shape[2], c.audio["num_mels"]) + self.assertEqual(wavs.shape[1], mel_input.shape[1] * c.audio.hop_length) + self.assertIsInstance(speaker_name[0], str) # make sure that the computed mels and the waveform match and correctly computed mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy()) @@ -109,55 +109,58 @@ class TestTTSDataset(unittest.TestCase): # guarantee that both mel-spectrograms have the same size and that we will remove waveform padding mel_new = mel_new[:, :mel_lengths[0]] ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length) - mel_diff = (mel_new - mel_dataloader)[:, 0:ignore_seg] - assert abs(mel_diff.sum()) < 1e-5 + mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg] + self.assertLess(abs(mel_diff.sum()), 1e-5) # check normalization ranges if self.ap.symmetric_norm: - assert mel_input.max() <= self.ap.max_norm - assert mel_input.min() >= -self.ap.max_norm # pylint: disable=invalid-unary-operand-type - assert mel_input.min() < 0 + self.assertLessEqual(mel_input.max(), self.ap.max_norm) + self.assertGreaterEqual( + mel_input.min(), -self.ap.max_norm + ) # pylint: disable=invalid-unary-operand-type + self.assertLess(mel_input.min(), 0) else: - assert mel_input.max() <= self.ap.max_norm - assert mel_input.min() >= 0 + self.assertLessEqual(mel_input.max(), self.ap.max_norm) + self.assertGreaterEqual(mel_input.min(), 0) def test_batch_group_shuffle(self): if ok_ljspeech: dataloader, dataset = self._create_dataloader(2, c.r, 16) last_length = 0 - frames = dataset.items + frames = dataset.samples for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data["text"] - text_lengths = data["text_lengths"] - speaker_name = data["speaker_names"] - linear_input = data["linear"] - mel_input = data["mel"] mel_lengths = data["mel_lengths"] - stop_target = data["stop_targets"] - item_idx = data["item_idxs"] - avg_length = mel_lengths.numpy().mean() - assert avg_length >= last_length - dataloader.dataset.sort_and_filter_items() + dataloader.dataset.preprocess_samples() 
is_items_reordered = False - for idx, item in enumerate(dataloader.dataset.items): + for idx, item in enumerate(dataloader.dataset.samples): if item != frames[idx]: is_items_reordered = True break - assert is_items_reordered + self.assertGreaterEqual(avg_length, last_length) + self.assertTrue(is_items_reordered) + + def test_padding_and_spectrograms(self): + def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths): + self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding + self.assertNotEqual(linear_input[idx, -2].sum(), 0) + self.assertNotEqual(mel_input[idx, -1].sum(), 0) + self.assertNotEqual(mel_input[idx, -2].sum(), 0) + self.assertEqual(stop_target[idx, -1], 1) + self.assertEqual(stop_target[idx, -2], 0) + self.assertEqual(stop_target[idx].sum(), 1) + self.assertEqual(len(mel_lengths.shape), 1) + self.assertEqual(mel_lengths[idx], linear_input[idx].shape[0]) + self.assertEqual(mel_lengths[idx], mel_input[idx].shape[0]) - def test_padding_and_spec(self): if ok_ljspeech: - dataloader, dataset = self._create_dataloader(1, 1, 0) + dataloader, _ = self._create_dataloader(1, 1, 0) for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data["text"] - text_lengths = data["text_lengths"] - speaker_name = data["speaker_names"] linear_input = data["linear"] mel_input = data["mel"] mel_lengths = data["mel_lengths"] @@ -172,7 +175,7 @@ class TestTTSDataset(unittest.TestCase): # NOTE: Below needs to check == 0 but due to an unknown reason # there is a slight difference between two matrices. # TODO: Check this assert cond more in detail. - assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T - mel_dl).max() + self.assertLess(abs(mel.T - mel_dl).max(), 1e-5) # check mel-spec correctness mel_spec = mel_input[0].cpu().numpy() @@ -186,56 +189,36 @@ class TestTTSDataset(unittest.TestCase): self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav") shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav") - # check the last time step to be zero padded - assert linear_input[0, -1].sum() != 0 - assert linear_input[0, -2].sum() != 0 - assert mel_input[0, -1].sum() != 0 - assert mel_input[0, -2].sum() != 0 - assert stop_target[0, -1] == 1 - assert stop_target[0, -2] == 0 - assert stop_target.sum() == 1 - assert len(mel_lengths.shape) == 1 - assert mel_lengths[0] == linear_input[0].shape[0] - assert mel_lengths[0] == mel_input[0].shape[0] + # check the outputs + check_conditions(0, linear_input, mel_input, stop_target, mel_lengths) # Test for batch size 2 - dataloader, dataset = self._create_dataloader(2, 1, 0) + dataloader, _ = self._create_dataloader(2, 1, 0) for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data["text"] - text_lengths = data["text_lengths"] - speaker_name = data["speaker_names"] linear_input = data["linear"] mel_input = data["mel"] mel_lengths = data["mel_lengths"] stop_target = data["stop_targets"] item_idx = data["item_idxs"] + # set id to the longest sequence in the batch if mel_lengths[0] > mel_lengths[1]: idx = 0 else: idx = 1 - # check the first item in the batch - assert linear_input[idx, -1].sum() != 0 - assert linear_input[idx, -2].sum() != 0, linear_input - assert mel_input[idx, -1].sum() != 0 - assert mel_input[idx, -2].sum() != 0, mel_input - assert stop_target[idx, -1] == 1 - assert stop_target[idx, -2] == 0 - assert stop_target[idx].sum() == 1 - assert len(mel_lengths.shape) == 1 - assert mel_lengths[idx] == mel_input[idx].shape[0] - assert mel_lengths[idx] == 
linear_input[idx].shape[0] + # check the longer item in the batch + check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths) - # check the second itme in the batch - assert linear_input[1 - idx, -1].sum() == 0 - assert mel_input[1 - idx, -1].sum() == 0 - assert stop_target[1, mel_lengths[1] - 1] == 1 - assert stop_target[1, mel_lengths[1] :].sum() == stop_target.shape[1] - mel_lengths[1] - assert len(mel_lengths.shape) == 1 + # check the other item in the batch + self.assertEqual(linear_input[1 - idx, -1].sum(), 0) + self.assertEqual(mel_input[1 - idx, -1].sum(), 0) + self.assertEqual(stop_target[1, mel_lengths[1] - 1], 1) + self.assertEqual(stop_target[1, mel_lengths[1] :].sum(), stop_target.shape[1] - mel_lengths[1]) + self.assertEqual(len(mel_lengths.shape), 1) # check batch zero-frame conditions (zero-frame disabled) # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 From 84091096a65f9e69232376e4e6b4c13fbf37cc27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:29:57 +0100 Subject: [PATCH 085/214] Refactor Synthesizer class for TTSTokenizer --- TTS/utils/synthesizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 2e4f4735..a06a493f 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -114,7 +114,8 @@ class Synthesizer(object): self.tts_config = load_config(tts_config_path) self.use_phonemes = self.tts_config.use_phonemes - self.tts_model = setup_tts_model(config=self.tts_config) + self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) + self.tokenizer = TTSTokenizer.init_from_config(self.tts_config) speaker_manager = self._init_speaker_manager() language_manager = self._init_language_manager() @@ -332,6 +333,8 @@ class Synthesizer(object): text=sen, CONFIG=self.tts_config, use_cuda=self.use_cuda, + ap=self.ap, + tokenizer=self.tokenizer, speaker_id=speaker_id, language_id=language_id, language_name=language_name, From b2bb954a51bb4eb0a29d5205460ebfbae6449176 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:33:21 +0100 Subject: [PATCH 086/214] Refactor TTSDataset to use TTSTokenizer --- TTS/tts/datasets/dataset.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index d4a12c07..9b78ddba 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -69,6 +69,9 @@ class TTSDataset(Dataset): samples (list): List of dataset samples. + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else + use the given. Defaults to None. + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else use the given. Defaults to None. @@ -202,6 +205,20 @@ class TTSDataset(Dataset): token_ids = self.tokenizer.text_to_ids(text) return np.array(token_ids, dtype=np.int32) + @staticmethod + def _parse_sample(item): + language_name = None + attn_file = None + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item + elif len(item) == 4: + text, wav_file, speaker_name, language_name = item + elif len(item) == 3: + text, wav_file, speaker_name = item + else: + raise ValueError(" [!] 
Dataset cannot parse the sample.") + return text, wav_file, speaker_name, language_name, attn_file + def load_data(self, idx): item = self.samples[idx] From b6c2bfdf086c2404d78b198a8f4a34d352c015b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:34:45 +0100 Subject: [PATCH 087/214] Refactor synthesis.py for TTSTokenizer --- TTS/tts/utils/synthesis.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 7bbc282f..979769a8 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -113,11 +113,8 @@ def synthesis( text, CONFIG, use_cuda, -<<<<<<< HEAD -======= ap, tokenizer, ->>>>>>> Refactor synthesis.py for TTSTokenizer speaker_id=None, style_wav=None, use_griffin_lim=False, From 8071fa0020183d383868294b2f42a9bbebbb8851 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:36:35 +0100 Subject: [PATCH 088/214] Refactor GlowTTS model and recipe for TTSTokenizer --- TTS/tts/models/base_tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 27231790..64086a84 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -287,7 +287,7 @@ class BaseTTS(BaseModel): verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, - tokenizer=self.tokenizer, + tokenizer=self.tokenizer ) # wait all the DDP process to be ready From 452dbc43d82bc0dd78883ce510f521be3e2d9557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 17 Nov 2021 12:46:04 +0100 Subject: [PATCH 089/214] Update imports for symbols -> characters --- TTS/tts/models/base_tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 64086a84..27231790 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -287,7 +287,7 @@ class BaseTTS(BaseModel): verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, - tokenizer=self.tokenizer + tokenizer=self.tokenizer, ) # wait all the DDP process to be ready From 9bb347a52b1bd3efd30b178b358005e825bc29ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 17:49:20 +0100 Subject: [PATCH 090/214] Update for tokenizer API --- TTS/utils/synthesizer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index a06a493f..2e4f4735 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -114,8 +114,7 @@ class Synthesizer(object): self.tts_config = load_config(tts_config_path) self.use_phonemes = self.tts_config.use_phonemes - self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) - self.tokenizer = TTSTokenizer.init_from_config(self.tts_config) + self.tts_model = setup_tts_model(config=self.tts_config) speaker_manager = self._init_speaker_manager() language_manager = self._init_language_manager() @@ -333,8 +332,6 @@ class Synthesizer(object): text=sen, CONFIG=self.tts_config, use_cuda=self.use_cuda, - ap=self.ap, - tokenizer=self.tokenizer, speaker_id=speaker_id, language_id=language_id, language_name=language_name, From 04df0a3d9f4a38be0c4c5d617365003d115c9ed1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:50:18 +0100 Subject: [PATCH 091/214] 
=?UTF-8?q?Refactor=20TTSDataset=20=E2=9A=A1?=
 =?UTF-8?q?=EF=B8=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 TTS/tts/datasets/dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 9b78ddba..9de40c2b 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -216,8 +216,8 @@ class TTSDataset(Dataset):
         elif len(item) == 3:
             text, wav_file, speaker_name = item
         else:
-            raise ValueError(" [!] Dataset cannot parse the sample.")
-        return text, wav_file, speaker_name, language_name, attn_file
+            token_ids = self.tokenizer.text_to_ids(text)
+        return token_ids

     def load_data(self, idx):
         item = self.samples[idx]

From 93957d58a100775feefea9eab6c38cd6a5d1869e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 30 Nov 2021 15:55:36 +0100
Subject: [PATCH 092/214] Refactoring VITS for the tokenizer API

---
 TTS/tts/models/vits.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 23be6177..1de26913 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -275,10 +275,7 @@ class Vits(BaseTTS):

     # pylint: disable=dangerous-default-value
     def __init__(
-        self,
-        config: Coqpit,
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
+        self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None, language_manager: LanguageManager = None
     ):
         super().__init__(config, ap, tokenizer, speaker_manager)

From 90cc45dd4e942926c72172b3b5f6c04d78b198a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 1 Dec 2021 10:06:02 +0100
Subject: [PATCH 093/214] Update data loader tests

---
 TTS/tts/datasets/dataset.py | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 9de40c2b..d4a12c07 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -69,9 +69,6 @@ class TTSDataset(Dataset):
         samples (list):
             List of dataset samples.

-        tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
-            use the given. Defaults to None.
-
         tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
             use the given. Defaults to None.
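
Taken together, patches 091-093 move `TTSDataset` onto dict-style samples and route all text handling through the shared tokenizer. A minimal sketch of the resulting lookup path; the standalone helper is illustrative only, while `samples`, `tokenizer`, `text_to_ids`, and the int32 cast are taken from the diffs in this series:

    import numpy as np

    def get_token_ids(dataset, idx):
        # samples are dicts after the refactor; no positional tuple unpacking
        text = dataset.samples[idx]["text"]
        # text-to-ID conversion is delegated to the TTSTokenizer
        token_ids = dataset.tokenizer.text_to_ids(text)
        return np.array(token_ids, dtype=np.int32)

For example, `get_token_ids(dataset, 0)` returns the token ID sequence for the first sample. The hunk below then drops the now-unused `_parse_sample` helper: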
@@ -205,20 +202,6 @@ class TTSDataset(Dataset): token_ids = self.tokenizer.text_to_ids(text) return np.array(token_ids, dtype=np.int32) - @staticmethod - def _parse_sample(item): - language_name = None - attn_file = None - if len(item) == 5: - text, wav_file, speaker_name, language_name, attn_file = item - elif len(item) == 4: - text, wav_file, speaker_name, language_name = item - elif len(item) == 3: - text, wav_file, speaker_name = item - else: - token_ids = self.tokenizer.text_to_ids(text) - return token_ids - def load_data(self, idx): item = self.samples[idx] From 30cfafce569b1d8d09b6efe10583ba68f6ae7b7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 08:56:57 +0000 Subject: [PATCH 094/214] Add init_from_config --- TTS/vocoder/models/base_vocoder.py | 1 + TTS/vocoder/models/wavegrad.py | 4 ++++ TTS/vocoder/models/wavernn.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py index 9d6ef26f..2728525c 100644 --- a/TTS/vocoder/models/base_vocoder.py +++ b/TTS/vocoder/models/base_vocoder.py @@ -20,6 +20,7 @@ class BaseVocoder(BaseModel): def __init__(self, config): super().__init__(config) + self._set_model_args(config) def _set_model_args(self, config: Coqpit): """Setup model args based on the config type. diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 00142c91..9d6e431c 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -339,3 +339,7 @@ class Wavegrad(BaseVocoder): noise_schedule = self.config["train_noise_schedule"] betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) self.compute_noise_level(betas) + + @staticmethod + def init_from_config(config: "WavegradConfig"): + return Wavegrad(config) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index b5b2343a..68f9b2c8 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -631,3 +631,7 @@ class Wavernn(BaseVocoder): def get_criterion(self): # define train functions return WaveRNNLoss(self.args.mode) + + @staticmethod + def init_from_config(config: "WavernnConfig"): + return Wavernn(config) From c9972e6f145a3cbf3b79c24b5c18c6d654b86f46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:51:58 +0000 Subject: [PATCH 095/214] Make lint --- TTS/tts/datasets/__init__.py | 4 +- TTS/tts/datasets/dataset.py | 157 +++++++------- TTS/tts/utils/synthesis.py | 4 +- TTS/tts/utils/text/characters.py | 191 ++++++++++++------ TTS/tts/utils/text/phonemizers/base.py | 36 ++-- .../utils/text/phonemizers/espeak_wrapper.py | 65 +++--- .../utils/text/phonemizers/gruut_wrapper.py | 3 +- .../text/phonemizers/ja_jp_phonemizer.py | 20 +- .../text/phonemizers/multi_phonemizer.py | 28 +-- .../text/phonemizers/zh_cn_phonemizer.py | 23 ++- TTS/tts/utils/text/punctuation.py | 18 +- TTS/utils/audio.py | 3 +- TTS/utils/synthesizer.py | 1 - TTS/vocoder/models/gan.py | 2 +- 14 files changed, 319 insertions(+), 236 deletions(-) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index d80e92c9..f0a6ea95 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -111,8 +111,8 @@ def load_tts_samples( meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train # load attention masks for the duration predictor training - if dataset.meta_file_attn_mask: - meta_data = 
dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) + if d.meta_file_attn_mask: + meta_data = dict(load_attention_mask_meta_data(d["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): attn_file = meta_data[ins["audio_file"]].strip() meta_data_train_all[idx].update({"alignment_file": attn_file}) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index d4a12c07..210de803 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -1,7 +1,6 @@ import collections import os import random -from multiprocessing import Pool from typing import Dict, List, Union import numpy as np @@ -10,7 +9,6 @@ import tqdm from torch.utils.data import Dataset from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor -from TTS.tts.utils.text import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -183,7 +181,7 @@ class TTSDataset(Dataset): def get_phonemes(self, idx, text): out_dict = self.phoneme_dataset[idx] assert text == out_dict["text"], f"{text} != {out_dict['text']}" - assert out_dict["token_ids"].size > 0 + assert len(out_dict["token_ids"]) > 0 return out_dict def get_f0(self, idx): @@ -192,7 +190,8 @@ class TTSDataset(Dataset): assert wav_file == out_dict["audio_file"] return out_dict - def get_attn_maks(self, attn_file): + @staticmethod + def get_attn_mask(attn_file): return np.load(attn_file) def get_token_ids(self, idx, text): @@ -207,7 +206,7 @@ class TTSDataset(Dataset): raw_text = item["text"] - wav = np.asarray(self.load_wav(item[]), dtype=np.float32) + wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) # apply noise for augmentation if self.use_noise_augment: @@ -262,7 +261,7 @@ class TTSDataset(Dataset): idxs = np.argsort(lengths) # ascending order ignore_idx = [] keep_idx = [] - for i, idx in enumerate(idxs): + for idx in idxs: length = lengths[idx] if length < min_len or length > max_len: ignore_idx.append(idx) @@ -277,6 +276,7 @@ class TTSDataset(Dataset): @staticmethod def create_buckets(samples, batch_group_size: int): + assert batch_group_size > 0 for i in range(len(samples) // batch_group_size): offset = i * batch_group_size end_offset = offset + batch_group_size @@ -319,7 +319,8 @@ class TTSDataset(Dataset): # shuffle batch groups # create batches with similar length items # the larger the `batch_group_size`, the higher the length variety in a batch. - samples = self.create_buckets(samples, self.batch_group_size) + if self.batch_group_size > 0: + samples = self.create_buckets(samples, self.batch_group_size) # update items to the new sorted items self.samples = samples @@ -571,6 +572,7 @@ class PhonemeDataset(Dataset): We use pytorch dataloader because we are lazy. 
""" + print("[*] Pre-computing phonemes...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 dataloder = torch.utils.data.DataLoader( @@ -658,16 +660,21 @@ class F0Dataset: return len(self.samples) def precompute(self, num_workers=0): + print("[*] Pre-computing F0s...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_f0 = self.normalize_f0 + self.normalize_f0 = False dataloder = torch.utils.data.DataLoader( batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn ) computed_data = [] for batch in dataloder: f0 = batch["f0"] - computed_data.append([f for f in f0]) + computed_data.append(f for f in f0) pbar.update(batch_size) + self.normalize_f0 = normalize_f0 if self.normalize_f0: computed_data = [tensor for batch in computed_data for tensor in batch] # flatten @@ -746,80 +753,80 @@ class F0Dataset: print(f"{indent}| > Number of instances : {len(self.samples)}") -if __name__ == "__main__": - from torch.utils.data import DataLoader +# if __name__ == "__main__": +# from torch.utils.data import DataLoader - from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig - from TTS.tts.datasets import load_tts_samples - from TTS.tts.utils.text.characters import IPAPhonemes - from TTS.tts.utils.text.phonemizers import ESpeak +# from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig +# from TTS.tts.datasets import load_tts_samples +# from TTS.tts.utils.text.characters import IPAPhonemes +# from TTS.tts.utils.text.phonemizers import ESpeak - dataset_config = BaseDatasetConfig( - name="ljspeech", - meta_file_train="metadata.csv", - path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1", - ) - train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) - samples = train_samples + eval_samples +# dataset_config = BaseDatasetConfig( +# name="ljspeech", +# meta_file_train="metadata.csv", +# path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1", +# ) +# train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) +# samples = train_samples + eval_samples - phonemizer = ESpeak(language="en-us") - tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer) - # ph_dataset = PhonemeDataset(samples, tokenizer, phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests") - # ph_dataset.precompute(num_workers=4) +# phonemizer = ESpeak(language="en-us") +# tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer) +# # ph_dataset = PhonemeDataset(samples, tokenizer, phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests") +# # ph_dataset.precompute(num_workers=4) - # dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn) - # for batch in dataloader: - # print(batch) - # break +# # dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn) +# # for batch in dataloader: +# # print(batch) +# # break - audio_config = BaseAudioConfig( - sample_rate=22050, - win_length=1024, - hop_length=256, - num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=True, - trim_db=45, - mel_fmin=0, - mel_fmax=8000, - spec_gain=1.0, - signal_norm=False, - do_amp_to_db_linear=False, - ) +# audio_config = BaseAudioConfig( +# 
sample_rate=22050, +# win_length=1024, +# hop_length=256, +# num_mels=80, +# preemphasis=0.0, +# ref_level_db=20, +# log_func="np.log", +# do_trim_silence=True, +# trim_db=45, +# mel_fmin=0, +# mel_fmax=8000, +# spec_gain=1.0, +# signal_norm=False, +# do_amp_to_db_linear=False, +# ) - ap = AudioProcessor.init_from_config(audio_config) +# ap = AudioProcessor.init_from_config(audio_config) - # f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4) +# # f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4) - # dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn) - # for batch in dataloader: - # print(batch) - # breakpoint() - # break +# # dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn) +# # for batch in dataloader: +# # print(batch) +# # breakpoint() +# # break - dataset = TTSDataset( - outputs_per_step=1, - compute_linear_spec=False, - samples=samples, - ap=ap, - return_wav=False, - batch_group_size=0, - min_seq_len=0, - max_seq_len=500, - use_noise_augment=False, - verbose=True, - speaker_id_mapping=None, - d_vector_mapping=None, - compute_f0=True, - f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests", - tokenizer=tokenizer, - phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests", - precompute_num_workers=4, - ) +# dataset = TTSDataset( +# outputs_per_step=1, +# compute_linear_spec=False, +# samples=samples, +# ap=ap, +# return_wav=False, +# batch_group_size=0, +# min_seq_len=0, +# max_seq_len=500, +# use_noise_augment=False, +# verbose=True, +# speaker_id_mapping=None, +# d_vector_mapping=None, +# compute_f0=True, +# f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests", +# tokenizer=tokenizer, +# phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests", +# precompute_num_workers=4, +# ) - dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn) - for batch in dataloader: - print(batch) - break +# dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn) +# for batch in dataloader: +# print(batch) +# break diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 979769a8..65dcc1ad 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -199,10 +199,10 @@ def synthesis( wav = model_outputs.squeeze(0) else: if use_griffin_lim: - wav = inv_spectrogram(model_outputs, ap, CONFIG) + wav = inv_spectrogram(model_outputs, model.ap, CONFIG) # trim silence if do_trim_silence: - wav = trim_silence(wav, ap) + wav = trim_silence(wav, model.ap) return_dict = { "wav": wav, "alignments": alignments, diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index 24ce51f1..aae6844f 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,3 +1,8 @@ +from dataclasses import replace + +from TTS.tts.configs.shared_configs import CharactersConfig + + def parse_symbols(): return { "pad": _pad, @@ -29,46 +34,49 @@ _diacrilics = "ɚ˞ɫ" _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics -def create_graphemes( - characters=_characters, - punctuations=_punctuations, - pad=_pad, - eos=_eos, - bos=_bos, - blank=_blank, - unique=True, -): # pylint: 
disable=redefined-outer-name - """Function to create default characters and phonemes""" - # create graphemes - _graphemes = list(characters) - _graphemes = [bos] + _graphemes if len(bos) > 0 and bos is not None else _graphemes - _graphemes = [eos] + _graphemes if len(bos) > 0 and eos is not None else _graphemes - _graphemes = [pad] + _graphemes if len(bos) > 0 and pad is not None else _graphemes - _graphemes = [blank] + _graphemes if len(bos) > 0 and blank is not None else _graphemes - _graphemes = _graphemes + list(punctuations) - return _graphemes, _phonemes +# def create_graphemes( +# characters=_characters, +# punctuations=_punctuations, +# pad=_pad, +# eos=_eos, +# bos=_bos, +# blank=_blank, +# unique=True, +# ): # pylint: disable=redefined-outer-name +# """Function to create default characters and phonemes""" +# # create graphemes +# = ( +# sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) +# ) # this is to keep previous models compatible. +# _graphemes = list(characters) +# _graphemes = [bos] + _graphemes if len(bos) > 0 and bos is not None else _graphemes +# _graphemes = [eos] + _graphemes if len(bos) > 0 and eos is not None else _graphemes +# _graphemes = [pad] + _graphemes if len(bos) > 0 and pad is not None else _graphemes +# _graphemes = [blank] + _graphemes if len(bos) > 0 and blank is not None else _graphemes +# _graphemes = _graphemes + list(punctuations) +# return _graphemes, _phonemes -def create_phonemes( - phonemes=_phonemes, punctuations=_punctuations, pad=_pad, eos=_eos, bos=_bos, blank=_blank, unique=True -): - # create phonemes - _phonemes = None - _phonemes_sorted = ( - sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) - ) # this is to keep previous models compatible. - _phonemes = list(_phonemes_sorted) - _phonemes = [bos] + _phonemes if len(bos) > 0 and bos is not None else _phonemes - _phonemes = [eos] + _phonemes if len(bos) > 0 and eos is not None else _phonemes - _phonemes = [pad] + _phonemes if len(bos) > 0 and pad is not None else _phonemes - _phonemes = [blank] + _phonemes if len(bos) > 0 and blank is not None else _phonemes - _phonemes = _phonemes + list(punctuations) - _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) - return _phonemes +# def create_phonemes( +# phonemes=_phonemes, punctuations=_punctuations, pad=_pad, eos=_eos, bos=_bos, blank=_blank, unique=True +# ): +# # create phonemes +# _phonemes = None +# _phonemes_sorted = ( +# sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) +# ) # this is to keep previous models compatible. 
+# _phonemes = list(_phonemes_sorted) +# _phonemes = [bos] + _phonemes if len(bos) > 0 and bos is not None else _phonemes +# _phonemes = [eos] + _phonemes if len(bos) > 0 and eos is not None else _phonemes +# _phonemes = [pad] + _phonemes if len(bos) > 0 and pad is not None else _phonemes +# _phonemes = [blank] + _phonemes if len(bos) > 0 and blank is not None else _phonemes +# _phonemes = _phonemes + list(punctuations) +# _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) +# return _phonemes -graphemes = create_graphemes(_characters, _phonemes, _punctuations, _pad, _eos, _bos) -phonemes = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank) +# DEF_GRAPHEMES = create_graphemes(_characters, _phonemes, _punctuations, _pad, _eos, _bos) +# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank) class BaseCharacters: @@ -114,7 +122,7 @@ class BaseCharacters: eos: str, bos: str, blank: str, - is_unique: bool = True, + is_unique: bool = False, is_sorted: bool = True, ) -> None: self._characters = characters @@ -202,14 +210,20 @@ class BaseCharacters: _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab self._vocab = _vocab + list(self._punctuations) self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} - self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension + } if self.is_unique: + duplicates = {x for x in self.vocab if self.vocab.count(x) > 1} assert ( len(self.vocab) == len(self._char_to_id) == len(self._id_to_char) - ), f" [!] There are duplicate characters in the character set. {set([x for x in self.vocab if self.vocab.count(x) > 1])}" + ), f" [!] There are duplicate characters in the character set. {duplicates}" def char_to_id(self, char: str) -> int: - return self._char_to_id[char] + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e def id_to_char(self, idx: int) -> str: return self._id_to_char[idx] @@ -229,9 +243,23 @@ class BaseCharacters: print(f"{indent}| > Num chars: {self.num_chars}") @staticmethod - def init_from_config(config: "Coqpit"): - return BaseCharacters( - **config.characters if config.characters is not None else {}, + def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument + """Init your character class from a config. + + Implement this method for your subclass. + """ + ... + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=self._eos, + bos=self._bos, + blank=self._blank, + is_unique=self.is_unique, + is_sorted=self.is_sorted, ) @@ -275,31 +303,42 @@ class IPAPhonemes(BaseCharacters): eos: str = _eos, bos: str = _bos, blank: str = _blank, - is_unique: bool = True, + is_unique: bool = False, is_sorted: bool = True, ) -> None: super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) @staticmethod def init_from_config(config: "Coqpit"): + """Init a IPAPhonemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. 
+ """ # band-aid for compatibility with old models if "characters" in config and config.characters is not None: if "phonemes" in config.characters and config.characters.phonemes is not None: config.characters["characters"] = config.characters["phonemes"] - return IPAPhonemes( - characters=config.characters["characters"], - punctuations=config.characters["punctuations"], - pad=config.characters["pad"], - eos=config.characters["eos"], - bos=config.characters["bos"], - blank=config.characters["blank"], - is_unique=config.characters["is_unique"], - is_sorted=config.characters["is_sorted"], - ) - else: - return IPAPhonemes( - **config.characters if config.characters is not None else {}, + return ( + IPAPhonemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, ) + # use character set from config + if config.characters is not None: + return IPAPhonemes(**config.characters), config + # return default character set + characters = IPAPhonemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config class Graphemes(BaseCharacters): @@ -339,24 +378,42 @@ class Graphemes(BaseCharacters): eos: str = _eos, bos: str = _bos, blank: str = _blank, - is_unique: bool = True, + is_unique: bool = False, is_sorted: bool = True, ) -> None: super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) @staticmethod def init_from_config(config: "Coqpit"): - return Graphemes( - **config.characters if config.characters is not None else {}, - ) + """Init a Graphemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. + """ + if config.characters is not None: + # band-aid for compatibility with old models + if "phonemes" in config.characters: + return ( + Graphemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + return Graphemes(**config.characters), config + characters = Graphemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config if __name__ == "__main__": gr = Graphemes() ph = IPAPhonemes() - - print(gr.vocab) - print(ph.vocab) - - print(gr.num_chars) - assert "a" == gr.id_to_char(gr.char_to_id("a")) + gr.print_log() + ph.print_log() diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py index 249c8bce..08fa8e13 100644 --- a/TTS/tts/utils/text/phonemizers/base.py +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -1,6 +1,5 @@ import abc -import itertools -from typing import List, Tuple, Union +from typing import List, Tuple from TTS.tts.utils.text.punctuation import Punctuation @@ -8,6 +7,19 @@ from TTS.tts.utils.text.punctuation import Punctuation class BasePhonemizer(abc.ABC): """Base phonemizer class + Phonemization follows the following steps: + 1. Preprocessing: + - remove empty lines + - remove punctuation + - keep track of punctuation marks + + 2. Phonemization: + - convert text to phonemes + + 3. 
Postprocessing: + - join phonemes + - restore punctuation marks + Args: language (str): Language used by the phonemizer. @@ -51,40 +63,30 @@ class BasePhonemizer(abc.ABC): @abc.abstractmethod def name(): """The name of the backend""" + ... @classmethod @abc.abstractmethod def is_available(cls): """Returns True if the backend is installed, False otherwise""" + ... @classmethod @abc.abstractmethod def version(cls): """Return the backend version as a tuple (major, minor, patch)""" + ... + @staticmethod @abc.abstractmethod def supported_languages(): """Return a dict of language codes -> name supported by the backend""" + ... def is_supported_language(self, language): """Returns True if `language` is supported by the backend""" return language in self.supported_languages() - fr""" - Phonemization follows the following steps: - 1. Preprocessing: - - remove empty lines - - remove punctuation - - keep track of punctuation marks - - 2. Phonemization: - - convert text to phonemes - - 3. Postprocessing: - - join phonemes - - restore punctuation marks - """ - @abc.abstractmethod def _phonemize(self, text, separator): """The main phonemization method""" diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index f1d0b6cd..3cccee41 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -28,29 +28,30 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: "1", # UTF8 text encoding ] cmd.extend(args) - logging.debug("espeakng: executing %s" % repr(cmd)) - p = subprocess.Popen( + logging.debug("espeakng: executing %s", repr(cmd)) + + with subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - ) - res = iter(p.stdout.readline, b"") - if not sync: + ) as p: + res = iter(p.stdout.readline, b"") + if not sync: + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + return res + res2 = [] + for line in res: + res2.append(line) p.stdout.close() if p.stderr: p.stderr.close() if p.stdin: p.stdin.close() - return res - res2 = [] - for line in res: - res2.append(line) - p.stdout.close() - if p.stderr: - p.stderr.close() - if p.stdin: - p.stdin.close() - p.wait() + p.wait() return res2 @@ -85,7 +86,24 @@ class ESpeak(BasePhonemizer): def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True): if self._ESPEAK_LIB is None: raise Exception("Unknown backend: %s" % backend) + + # band-aid for backwards compatibility + if language == "en": + language = "en-us" + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + if backend is not None: + self.backend = backend + + @property + def backend(self): + return self._ESPEAK_LIB + + @backend.setter + def backend(self, backend): + if backend not in ["espeak", "espeak-ng"]: + raise Exception("Unknown backend: %s" % backend) + self._ESPEAK_LIB = backend def auto_set_espeak_lib(self) -> None: if is_tool("espeak-ng"): @@ -115,24 +133,25 @@ class ESpeak(BasePhonemizer): # espeak and espeak-ng parses `ipa` differently if tie: # use '͡' between phonemes - if _DEF_ESPEAK_LIB == "espeak": + if self.backend == "espeak": args.append("--ipa=1") else: args.append("--ipa=3") else: # split with '_' - if _DEF_ESPEAK_LIB == "espeak": + if self.backend == "espeak": args.append("--ipa=3") else: args.append("--ipa=1") if tie: args.append("--tie=%s" % tie) + args.append('"' + text + '"') # compute phonemes phonemes = "" for line in 
_espeak_exe(self._ESPEAK_LIB, args, sync=True): - logging.debug("line: %s" % repr(line)) - phonemes += line.decode("utf8").strip() + logging.debug("line: %s", repr(line)) + phonemes += line.decode("utf8").strip()[2:] # skip two redundant characters return phonemes.replace("_", separator) def _phonemize(self, text, separator=None): @@ -146,7 +165,7 @@ class ESpeak(BasePhonemizer): Dict: Dictionary of language codes. """ if _DEF_ESPEAK_LIB is None: - raise {} + return {} args = ["--voices"] langs = {} count = 0 @@ -157,7 +176,7 @@ class ESpeak(BasePhonemizer): lang_code = cols[1] lang_name = cols[3] langs[lang_code] = lang_name - logging.debug("line: %s" % repr(line)) + logging.debug("line: %s", repr(line)) count += 1 return langs @@ -168,9 +187,9 @@ class ESpeak(BasePhonemizer): str: Version of the used backend. """ args = ["--version"] - for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True): + for line in _espeak_exe(self.backend, args, sync=True): version = line.decode("utf8").strip().split()[2] - logging.debug("line: %s" % repr(line)) + logging.debug("line: %s", repr(line)) return version @classmethod diff --git a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py index d0aa469e..f3e9c9ab 100644 --- a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py @@ -1,5 +1,4 @@ import importlib -from os import stat from typing import List import gruut @@ -55,7 +54,7 @@ class Gruut(BasePhonemizer): def name(): return "gruut" - def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: + def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument """Convert input text to phonemes. Gruut phonemizes the given `str` by seperating each phoneme character with `separator`, even for characters diff --git a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py index 4f93edeb..60b965f9 100644 --- a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py +++ b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py @@ -30,7 +30,7 @@ class JA_JP_Phonemizer(BasePhonemizer): language = "ja-jp" - def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): + def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) @staticmethod @@ -61,12 +61,12 @@ class JA_JP_Phonemizer(BasePhonemizer): return True -if __name__ == "__main__": - text = "これは、電話をかけるための私の日本語の例のテキストです。" - e = JA_JP_Phonemizer() - print(e.supported_languages()) - print(e.version()) - print(e.language) - print(e.name()) - print(e.is_available()) - print("`" + e.phonemize(text) + "`") +# if __name__ == "__main__": +# text = "これは、電話をかけるための私の日本語の例のテキストです。" +# e = JA_JP_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py index e8b2ce34..e36b0a2a 100644 --- a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py +++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -17,7 +17,7 @@ class MultiPhonemizer: lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER language = "multi-lingual" - def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: + def __init__(self, 
custom_lang_to_phonemizer: Dict = {}) -> None: # pylint: disable=dangerous-default-value self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer) self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name) @@ -40,16 +40,16 @@ class MultiPhonemizer: return list(self.lang_to_phonemizer_name.keys()) -if __name__ == "__main__": - texts = { - "tr": "Merhaba, bu Türkçe bit örnek!", - "en-us": "Hello, this is English example!", - "de": "Hallo, das ist ein Deutches Beipiel!", - "zh-cn": "这是中国的例子", - } - phonemes = {} - ph = MultiPhonemizer() - for lang, text in texts.items(): - phoneme = ph.phonemize(text, lang) - phonemes[lang] = phoneme - print(phonemes) +# if __name__ == "__main__": +# texts = { +# "tr": "Merhaba, bu Türkçe bit örnek!", +# "en-us": "Hello, this is English example!", +# "de": "Hallo, das ist ein Deutches Beipiel!", +# "zh-cn": "这是中国的例子", +# } +# phonemes = {} +# ph = MultiPhonemizer() +# for lang, text in texts.items(): +# phoneme = ph.phonemize(text, lang) +# phonemes[lang] = phoneme +# print(phonemes) diff --git a/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py index e1bd77c7..5a4a5591 100644 --- a/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py +++ b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py @@ -25,14 +25,15 @@ class ZH_CN_Phonemizer(BasePhonemizer): language = "zh-cn" - def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): + def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) @staticmethod def name(): return "zh_cn_phonemizer" - def phonemize_zh_cn(self, text: str, separator: str = "|") -> str: + @staticmethod + def phonemize_zh_cn(text: str, separator: str = "|") -> str: ph = chinese_text_to_phonemes(text, separator) return ph @@ -50,12 +51,12 @@ class ZH_CN_Phonemizer(BasePhonemizer): return True -if __name__ == "__main__": - text = "这是,样本中文。" - e = ZH_CN_Phonemizer() - print(e.supported_languages()) - print(e.version()) - print(e.language) - print(e.name()) - print(e.is_available()) - print("`" + e.phonemize(text) + "`") +# if __name__ == "__main__": +# text = "这是,样本中文。" +# e = ZH_CN_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py index 414ac253..09087d5f 100644 --- a/TTS/tts/utils/text/punctuation.py +++ b/TTS/tts/utils/text/punctuation.py @@ -130,7 +130,7 @@ class Punctuation: return cls._restore(text, puncs, 0) @classmethod - def _restore(cls, text, puncs, num): + def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements """Auxiliary method for Punctuation.restore()""" if not puncs: return text @@ -159,14 +159,14 @@ class Punctuation: return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num) -if __name__ == "__main__": - punc = Punctuation() - text = "This is. This is, example!" +# if __name__ == "__main__": +# punc = Punctuation() +# text = "This is. This is, example!" 
- print(punc.strip(text)) +# print(punc.strip(text)) - split_text, puncs = punc.strip_to_restore(text) - print(split_text, " ---- ", puncs) +# split_text, puncs = punc.strip_to_restore(text) +# print(split_text, " ---- ", puncs) - restored_text = punc.restore(split_text, puncs) - print(restored_text) +# restored_text = punc.restore(split_text, puncs) +# print(restored_text) diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index bdee8615..bfa0e5e1 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -383,8 +383,7 @@ class AudioProcessor(object): def init_from_config(config: "Coqpit"): if "audio" in config: return AudioProcessor(**config.audio) - else: - return AudioProcessor(**config) + return AudioProcessor(**config) ### setting up the parameters ### def _build_mel_basis( diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 2e4f4735..f6a1ae6a 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -13,7 +13,6 @@ from TTS.tts.utils.speakers import SpeakerManager # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import from TTS.tts.utils.synthesis import synthesis, trim_silence -from TTS.tts.utils.text import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index e56d1db4..f78d69b8 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -314,7 +314,7 @@ class GAN(BaseVocoder): data_items: List, verbose: bool, num_gpus: int, - rank: int = 0, # pylint: disable=unused-argument + rank: int = None, # pylint: disable=unused-argument ): """Initiate and return the GAN dataloader. From 8649d4fd367da3306a7f9aae3c778115444d89ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:52:45 +0000 Subject: [PATCH 096/214] Allow None pad and blank tokens --- TTS/tts/utils/text/tokenizer.py | 48 ++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 68a1c575..3f416bbb 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -57,8 +57,8 @@ class TTSTokenizer: @characters.setter def characters(self, new_characters): self._characters = new_characters - self.pad_id = self.characters.char_to_id(self.characters.pad) - self.blank_id = self.characters.char_to_id(self.characters.blank) + self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None + self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None def encode(self, text: str) -> List[int]: """Encodes a string of text as a sequence of IDs.""" @@ -82,7 +82,7 @@ class TTSTokenizer: text += self.characters.id_to_char(token_id) return text - def text_to_ids(self, text: str, language: str = None) -> List[int]: + def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument """Converts a string of text to a sequence of token IDs. Args: @@ -137,32 +137,50 @@ class TTSTokenizer: print(f"{indent}| > {char}") @staticmethod - def init_from_config(config: "Coqpit"): + def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): """Init Tokenizer object from config Args: config (Coqpit): Coqpit model config. + characters (BaseCharacters): Defines the model character set. 
If not set, use the default options based on + the config values. Defaults to None. """ # init cleaners if isinstance(config.text_cleaner, (str, list)): text_cleaner = getattr(cleaners, config.text_cleaner) + # init characters + if characters is None: + if config.use_phonemes: + # init phoneme set + characters, new_config = IPAPhonemes().init_from_config(config) + else: + # init character set + characters, new_config = Graphemes().init_from_config(config) + else: + characters, new_config = characters.init_from_config(config) + + # init phonemizer phonemizer = None if config.use_phonemes: - # init phoneme set - characters = IPAPhonemes().init_from_config(config) phonemizer_kwargs = {"language": config.phoneme_language} - # init phonemizer if "phonemizer" in config and config.phonemizer: phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs) else: - phonemizer = get_phonemizer_by_name( - DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs - ) - else: - # init character set - characters = Graphemes().init_from_config(config) - return TTSTokenizer( - config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars + try: + phonemizer = get_phonemizer_by_name( + DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs + ) + except KeyError as e: + raise ValueError( + f"""No phonemizer found for language {config.phoneme_language}. + You may need to install a third party library for this language.""" + ) from e + + return ( + TTSTokenizer( + config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars + ), + new_config, ) From bde68d9f25de40e2d8f4dbf8aa1b0b022863bcdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:53:25 +0000 Subject: [PATCH 097/214] Use the same phonemizer for `en` to `en-us` --- TTS/tts/utils/text/phonemizers/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py index 0da5875e..5dc117c4 100644 --- a/TTS/tts/utils/text/phonemizers/__init__.py +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -29,6 +29,8 @@ _ = [ESpeak.name()] * len(ESPEAK_LANGS) _new_dict = dict(list(zip(list(ESPEAK_LANGS), _))) DEF_LANG_TO_PHONEMIZER.update(_new_dict) +DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"] + def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: """Initiate a phonemizer by name From f802a931a399de8d563facc38c2d618f7cde0b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:54:39 +0000 Subject: [PATCH 098/214] Pass samples to init_from_config in SpeakerManager --- TTS/tts/utils/speakers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index c556db79..ba48f27c 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -319,23 +319,27 @@ class SpeakerManager: raise NotImplementedError @staticmethod - def init_from_config(config: "Coqpit"): + def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager": """Initialize a speaker manager from config Args: config (Coqpit): Config object. + samples (Union[List[List], List[Dict]], optional): List of data samples to parse out the speaker names. + Defaults to None. Returns: SpeakerEncoder: Speaker encoder object. 
""" speaker_manager = None - if hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True: + if hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding: + if samples: + speaker_manager = SpeakerManager(data_items=samples) if config.get("speaker_file", None): speaker_manager = SpeakerManager(speaker_id_file_path=config.speaker_file) if config.get("speakers_file", None): speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file) - if hasattr(config, "use_d_vector_file") and config.use_speaker_embedding is True: + if hasattr(config, "use_d_vector_file") and config.use_d_vector_file: if config.get("speakers_file", None): speaker_manager = SpeakerManager(d_vectors_file_path=config.speaker_file) if config.get("d_vector_file", None): From ea965a5683c56a39570b4cc91e86cd2bb9799308 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:55:18 +0000 Subject: [PATCH 099/214] Update VITS for the new API --- TTS/tts/models/vits.py | 213 +++++++++++++++++++++-------------------- 1 file changed, 107 insertions(+), 106 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 1de26913..30dc7ec4 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,7 +1,8 @@ import math -from dataclasses import dataclass, field +import random +from dataclasses import dataclass, field, replace from itertools import chain -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import torch import torchaudio @@ -10,6 +11,7 @@ from torch import nn from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F +from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder @@ -19,6 +21,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment from TTS.utils.trainer_utils import get_optimizer, get_scheduler @@ -283,91 +286,79 @@ class Vits(BaseTTS): self.END2END = True self.speaker_manager = speaker_manager self.language_manager = language_manager - if config.__class__.__name__ == "VitsConfig": - # loading from VitsConfig - self.num_chars = self.tokenizer.characters.num_chars - self.config = config - args = self.config.model_args - elif isinstance(config, VitsArgs): - # loading from VitsArgs - self.config = config - args = config - else: - raise ValueError("config must be either a VitsConfig or VitsArgs") self.args = args self.init_multispeaker(config) self.init_multilingual(config) - self.length_scale = args.length_scale - self.noise_scale = args.noise_scale - self.inference_noise_scale = args.inference_noise_scale - self.inference_noise_scale_dp = args.inference_noise_scale_dp - self.noise_scale_dp = args.noise_scale_dp - self.max_inference_len = args.max_inference_len - self.spec_segment_size = args.spec_segment_size + self.length_scale = self.args.length_scale + self.noise_scale = self.args.noise_scale + self.inference_noise_scale = 
self.args.inference_noise_scale + self.inference_noise_scale_dp = self.args.inference_noise_scale_dp + self.noise_scale_dp = self.args.noise_scale_dp + self.max_inference_len = self.args.max_inference_len + self.spec_segment_size = self.args.spec_segment_size self.text_encoder = TextEncoder( - args.num_chars, - args.hidden_channels, - args.hidden_channels, - args.hidden_channels_ffn_text_encoder, - args.num_heads_text_encoder, - args.num_layers_text_encoder, - args.kernel_size_text_encoder, - args.dropout_p_text_encoder, - language_emb_dim=self.embedded_language_dim, + self.args.num_chars, + self.args.hidden_channels, + self.args.hidden_channels, + self.args.hidden_channels_ffn_text_encoder, + self.args.num_heads_text_encoder, + self.args.num_layers_text_encoder, + self.args.kernel_size_text_encoder, + self.args.dropout_p_text_encoder, ) self.posterior_encoder = PosteriorEncoder( - args.out_channels, - args.hidden_channels, - args.hidden_channels, - kernel_size=args.kernel_size_posterior_encoder, - dilation_rate=args.dilation_rate_posterior_encoder, - num_layers=args.num_layers_posterior_encoder, + self.args.out_channels, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_posterior_encoder, + dilation_rate=self.args.dilation_rate_posterior_encoder, + num_layers=self.args.num_layers_posterior_encoder, cond_channels=self.embedded_speaker_dim, ) self.flow = ResidualCouplingBlocks( - args.hidden_channels, - args.hidden_channels, - kernel_size=args.kernel_size_flow, - dilation_rate=args.dilation_rate_flow, - num_layers=args.num_layers_flow, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_flow, + dilation_rate=self.args.dilation_rate_flow, + num_layers=self.args.num_layers_flow, cond_channels=self.embedded_speaker_dim, ) - if args.use_sdp: + if self.args.use_sdp: self.duration_predictor = StochasticDurationPredictor( - args.hidden_channels, + self.args.hidden_channels, 192, 3, - args.dropout_p_duration_predictor, + self.args.dropout_p_duration_predictor, 4, cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, language_emb_dim=self.embedded_language_dim, ) else: self.duration_predictor = DurationPredictor( - args.hidden_channels, + self.args.hidden_channels, 256, 3, - args.dropout_p_duration_predictor, - cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, + self.args.dropout_p_duration_predictor, + cond_channels=self.embedded_speaker_dim, language_emb_dim=self.embedded_language_dim, ) self.waveform_decoder = HifiganGenerator( - args.hidden_channels, + self.args.hidden_channels, 1, - args.resblock_type_decoder, - args.resblock_dilation_sizes_decoder, - args.resblock_kernel_sizes_decoder, - args.upsample_kernel_sizes_decoder, - args.upsample_initial_channel_decoder, - args.upsample_rates_decoder, + self.args.resblock_type_decoder, + self.args.resblock_dilation_sizes_decoder, + self.args.resblock_kernel_sizes_decoder, + self.args.upsample_kernel_sizes_decoder, + self.args.upsample_initial_channel_decoder, + self.args.upsample_rates_decoder, inference_padding=0, cond_channels=self.embedded_speaker_dim, conv_pre_weight_norm=False, @@ -375,8 +366,8 @@ class Vits(BaseTTS): conv_post_bias=False, ) - if args.init_discriminator: - self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator) + if self.args.init_discriminator: + self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator) def 
init_multispeaker(self, config: Coqpit): """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer @@ -883,19 +874,17 @@ class Vits(BaseTTS): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - ap = assets["audio_processor"] - self._log(ap, batch, outputs, "train") + self._log(self.ap, batch, outputs, "train") @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - return self._log(ap, batch, outputs, "eval") + return self._log(self.ap, batch, outputs, "eval") @torch.no_grad() - def test_run(self, ap) -> Tuple[Dict, Dict]: + def test_run(self) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. @@ -990,36 +979,6 @@ class Vits(BaseTTS): return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)] - @staticmethod - def make_symbols(config): - """Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the - whole training and inference steps.""" - _pad = config.characters["pad"] - _punctuations = config.characters["punctuations"] - _letters = config.characters["characters"] - _letters_ipa = config.characters["phonemes"] - symbols = [_pad] + list(_punctuations) + list(_letters) - if config.use_phonemes: - symbols += list(_letters_ipa) - return symbols - - @staticmethod - def get_characters(config: Coqpit): - if config.characters is not None: - symbols = Vits.make_symbols(config) - else: - from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel - parse_symbols, - phonemes, - symbols, - ) - - config.characters = parse_symbols() - if config.use_phonemes: - symbols = phonemes - num_chars = len(symbols) + getattr(config, "add_blank", False) - return symbols, config, num_chars - def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin @@ -1035,23 +994,65 @@ class Vits(BaseTTS): assert not self.training @staticmethod - def init_from_config(config: "Coqpit"): - """Initialize model from config.""" - - # init characters - if config.use_phonemes: - from TTS.tts.utils.text.characters import IPAPhonemes - - characters = IPAPhonemes().init_from_config(config) - else: - from TTS.tts.utils.text.characters import Graphemes - - characters = Graphemes().init_from_config(config) - config.num_chars = characters.num_chars + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ from TTS.utils.audio import AudioProcessor ap = AudioProcessor.init_from_config(config) - tokenizer = TTSTokenizer.init_from_config(config) - speaker_manager = SpeakerManager.init_from_config(config) - return Vits(config, ap, tokenizer, speaker_manager) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Vits(new_config, ap, tokenizer, speaker_manager) + + +class VitsCharacters(BaseCharacters): + """Characters class for VITs model for compatibility with pre-trained models""" + + def __init__( + self, + graphemes: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + ipa_characters: str = _phonemes, + ) -> None: + if ipa_characters is not None: + graphemes += ipa_characters + super().__init__(graphemes, punctuations, pad, None, None, "", is_unique=False, is_sorted=True) + + def _create_vocab(self): + self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + # pylint: disable=unnecessary-comprehension + self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + + @staticmethod + def init_from_config(config: Coqpit): + if config.characters is not None: + _pad = config.characters["pad"] + _punctuations = config.characters["punctuations"] + _letters = config.characters["characters"] + _letters_ipa = config.characters["phonemes"] + return ( + VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), + config, + ) + characters = VitsCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=None, + bos=None, + blank=self._blank, + is_unique=False, + is_sorted=True, + ) From d0ec4b91e5cc393fe29431fbe2c6c047fc5d5e4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:55:45 +0000 Subject: [PATCH 100/214] Update Tacotron models --- TTS/tts/models/base_tacotron.py | 22 +++++++++++++++-- TTS/tts/models/tacotron.py | 43 +++++++++++++++++++++++--------- TTS/tts/models/tacotron2.py | 44 +++++++++++++++++++++++---------- 3 files changed, 82 insertions(+), 27 deletions(-) diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index ca8f3bb9..54939c61 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -9,6 +9,8 @@ from torch import nn from TTS.tts.layers.losses import TacotronLoss from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.generic_utils import format_aux_input from TTS.utils.io import load_fsspec from TTS.utils.training import gradual_training_scheduler @@ -17,8 +19,14 @@ from TTS.utils.training import gradual_training_scheduler class BaseTacotron(BaseTTS): """Base class shared by Tacotron and Tacotron2""" - def __init__(self, config: Coqpit): - super().__init__(config) + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) # pass all config fields as class attributes for key in config: @@ -107,6 +115,16 
@@ class BaseTacotron(BaseTTS): """Get the model criterion used in training.""" return TacotronLoss(self.config) + @staticmethod + def init_from_config(config: Coqpit): + """Initialize model from config.""" + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config) + return BaseTacotron(config, ap, tokenizer, speaker_manager) + ############################# # COMMON COMPUTE FUNCTIONS ############################# diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 4e46d252..8341f5bb 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -1,7 +1,8 @@ # coding: utf-8 +from typing import Dict, List, Union + import torch -from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast @@ -10,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG from TTS.tts.models.base_tacotron import BaseTacotron from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -24,12 +26,15 @@ class Tacotron(BaseTacotron): a multi-speaker model. Defaults to None. """ - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): - super().__init__(config) + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): - self.speaker_manager = speaker_manager - chars, self.config, _ = self.get_characters(config) - config.num_chars = self.num_chars = len(chars) + super().__init__(config, ap, tokenizer, speaker_manager) # pass all config fields to `self` # for fewer code change @@ -302,16 +307,30 @@ class Tacotron(BaseTacotron): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (TacotronConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
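+
+        Example (illustrative sketch; assumes the default single-speaker config, so no real dataset is needed):
+
+            >>> from TTS.tts.configs.tacotron_config import TacotronConfig
+            >>> config = TacotronConfig()
+            >>> model = Tacotron.init_from_config(config)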
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Tacotron(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index ead3bf2b..d4e665e3 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -1,9 +1,8 @@ # coding: utf-8 -from typing import Dict +from typing import Dict, List, Union import torch -from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast @@ -12,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet from TTS.tts.models.base_tacotron import BaseTacotron from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -40,12 +40,16 @@ class Tacotron2(BaseTacotron): Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None. """ - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): - super().__init__(config) + def __init__( + self, + config: "Tacotron2Config", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + + super().__init__(config, ap, tokenizer, speaker_manager) - self.speaker_manager = speaker_manager - chars, self.config, _ = self.get_characters(config) - config.num_chars = len(chars) self.decoder_output_dim = config.out_channels # pass all config fields to `self` @@ -325,16 +329,30 @@ class Tacotron2(BaseTacotron): self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use """Log training progress.""" - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (Tacotron2Config): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(new_config, samples) + return Tacotron2(new_config, ap, tokenizer, speaker_manager) From 18f726af6594de93cd44c4c1bb7b9ccc037c3f58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:56:16 +0000 Subject: [PATCH 101/214] Update ForwardTTS --- TTS/tts/models/base_tts.py | 19 +++++++---------- TTS/tts/models/forward_tts.py | 40 ++++++++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 27231790..59862322 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -1,6 +1,6 @@ import os import random -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import torch import torch.distributed as dist @@ -56,9 +56,10 @@ class BaseTTS(BaseModel): """ # don't use isintance not to import recursively if "Config" in config.__class__.__name__: - num_chars = ( - self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + config_num_chars = ( + self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars ) + num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars if "characters" in config: self.config.num_chars = num_chars if hasattr(self.config, "model_args"): @@ -237,7 +238,7 @@ class BaseTTS(BaseModel): config: Coqpit, assets: Dict, is_eval: bool, - data_items: List, + samples: Union[List[Dict], List[List]], verbose: bool, num_gpus: int, rank: int = None, @@ -274,7 +275,7 @@ class BaseTTS(BaseModel): compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, compute_f0=config.get("compute_f0", False), f0_cache_path=config.get("f0_cache_path", None), - meta_data=data_items, + samples=samples, ap=self.ap, return_wav=config.return_wav if "return_wav" in config else False, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, @@ -283,6 +284,7 @@ class BaseTTS(BaseModel): min_audio_len=config.min_audio_len, max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, verbose=verbose, speaker_id_mapping=speaker_id_mapping, @@ -357,8 +359,6 @@ class BaseTTS(BaseModel): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
""" - ap = assets["audio_processor"] - tokenizer = assets["tokenizer"] print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} @@ -370,18 +370,15 @@ class BaseTTS(BaseModel): sen, self.config, "cuda" in str(next(self.parameters()).device), - ap, - tokenizer, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, use_griffin_lim=True, do_trim_silence=False, ) test_audios["{}-audio".format(idx)] = outputs_dict["wav"] test_figures["{}-prediction".format(idx)] = plot_spectrogram( - outputs_dict["outputs"]["model_outputs"], ap, output_fig=False + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False ) test_figures["{}-alignment".format(idx)] = plot_alignment( outputs_dict["outputs"]["alignments"], output_fig=False diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index b2c41df5..699f3142 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, Tuple +from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit @@ -14,6 +14,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram @@ -170,11 +171,16 @@ class ForwardTTS(BaseTTS): """ # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): - super().__init__(config) + super().__init__(config, ap, tokenizer, speaker_manager) - self.speaker_manager = speaker_manager self.init_multispeaker(config) self.max_duration = self.args.max_duration @@ -692,19 +698,17 @@ class ForwardTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( self, config, checkpoint_path, eval=False @@ -724,3 +728,19 @@ class ForwardTTS(BaseTTS): """Enable binary alignment loss when needed""" if trainer.total_steps_done > self.config.binary_align_loss_start_step: self.use_binary_alignment_loss = True + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from 
config + + Args: + config (ForwardTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return ForwardTTS(new_config, ap, tokenizer, speaker_manager) From bacf79f4fbb87500a43d9dbf48e466cc4d35b77a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:56:24 +0000 Subject: [PATCH 102/214] Update AlignTTS --- TTS/tts/models/align_tts.py | 43 ++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 2fc00b0b..c1e2ffb3 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Dict, List, Union import torch from coqpit import Coqpit @@ -12,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.io import load_fsspec @@ -100,11 +102,16 @@ class AlignTTS(BaseTTS): # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + def __init__( + self, + config: "AlignTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): - super().__init__(config) + super().__init__(config, ap, tokenizer, speaker_manager) self.speaker_manager = speaker_manager - self.config = config self.phase = -1 self.length_scale = ( float(config.model_args.length_scale) @@ -112,10 +119,6 @@ class AlignTTS(BaseTTS): else config.model_args.length_scale ) - if not self.config.model_args.num_chars: - _, self.config, num_chars = self.get_characters(config) - self.config.model_args.num_chars = num_chars - self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) self.embedded_speaker_dim = 0 @@ -382,19 +385,17 @@ class AlignTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( self, config, checkpoint_path, eval=False @@ -430,3 +431,19 @@ class AlignTTS(BaseTTS): def on_epoch_start(self, trainer): """Set 
AlignTTS training phase on epoch start.""" self.phase = self._set_phase(trainer.config, trainer.total_steps_done) + + @staticmethod + def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (AlignTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return AlignTTS(new_config, ap, tokenizer, speaker_manager) From 7c4243fba7738748006ee2ac2e806812616f02a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:56:31 +0000 Subject: [PATCH 103/214] Update GlowTTS --- TTS/tts/models/glow_tts.py | 48 ++++++++++++++------------------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 73680f32..7a48b023 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -1,5 +1,5 @@ import math -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit @@ -50,8 +50,8 @@ class GlowTTS(BaseTTS): def __init__( self, config: GlowTTSConfig, - ap: "AudioProcessor", - tokenizer: "TTSTokenizer", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, ): @@ -63,7 +63,6 @@ class GlowTTS(BaseTTS): for key in config: setattr(self, key, config[key]) - self.num_chars = self.tokenizer.characters.num_chars self.decoder_output_dim = config.out_channels # init multi-speaker layers if necessary @@ -429,20 +428,18 @@ class GlowTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: @@ -467,19 +464,16 @@ class GlowTTS(BaseTTS): sen, self.config, "cuda" in str(next(self.parameters()).device), - self.ap, - self.tokenizer, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, use_griffin_lim=True, do_trim_silence=False, ) test_audios["{}-audio".format(idx)] = outputs["wav"] test_figures["{}-prediction".format(idx)] = plot_spectrogram( - outputs["outputs"]["model_outputs"], ap, output_fig=False + outputs["outputs"]["model_outputs"], self.ap, output_fig=False ) test_figures["{}-alignment".format(idx)] = 
plot_alignment(outputs["alignments"], output_fig=False) return test_figures, test_audios @@ -516,23 +510,17 @@ class GlowTTS(BaseTTS): self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps @staticmethod - def init_from_config(config: Coqpit): - """Initialize model from config.""" - - # init characters - if config.use_phonemes: - from TTS.tts.utils.text.characters import IPAPhonemes - - characters = IPAPhonemes().init_from_config(config) - else: - from TTS.tts.utils.text.characters import Graphemes - - characters = Graphemes().init_from_config(config) - config.num_chars = characters.num_chars + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (GlowTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ from TTS.utils.audio import AudioProcessor ap = AudioProcessor.init_from_config(config) - tokenizer = TTSTokenizer.init_from_config(config) - speaker_manager = SpeakerManager.init_from_config(config) - return GlowTTS(config, ap, tokenizer, speaker_manager) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return GlowTTS(new_config, ap, tokenizer, speaker_manager) From 4c5cb44eeb528db360e8a0837b16597fa334c4c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:56:44 +0000 Subject: [PATCH 104/214] Update setup_model --- TTS/tts/models/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index cb1c2e21..d76a3beb 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,12 +1,14 @@ +from typing import Dict, List, Union + from TTS.utils.generic_utils import find_module -def setup_model(config: "Coqpit") -> "BaseTTS": +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": print(" > Using model: {}".format(config.model)) # fetch the right model implementation. if "base_model" in config and config["base_model"] is not None: MyModel = find_module("TTS.tts.models", config.base_model.lower()) else: MyModel = find_module("TTS.tts.models", config.model.lower()) - model = MyModel.init_from_config(config) + model = MyModel.init_from_config(config, samples) return model From cfaa51fddc19c7ac8b9c4607494fd9270eddd60c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:57:51 +0000 Subject: [PATCH 105/214] Update BaseTTS config --- TTS/tts/configs/shared_configs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 98461bdd..ad3bbe70 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -78,7 +78,7 @@ class CharactersConfig(Coqpit): is_unique (bool): remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old - models trained with character lists with duplicates. + models trained with character lists with duplicates. Defaults to True. is_sorted (bool): Sort the characters in alphabetical order. Defaults to True. @@ -166,6 +166,9 @@ class BaseTTSConfig(BaseTrainingConfig): compute_linear_spec (bool): If True data loader computes and returns linear spectrograms alongside the other data.
+ precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + use_noise_augment (bool): Augment the input audio with random noise. @@ -214,6 +217,7 @@ class BaseTTSConfig(BaseTrainingConfig): phoneme_cache_path: str = None # vocabulary parameters characters: CharactersConfig = None + add_blank: bool = False # training params batch_group_size: int = 0 loss_masking: bool = None @@ -225,8 +229,8 @@ class BaseTTSConfig(BaseTrainingConfig): max_text_len: int = float("inf") compute_f0: bool = False compute_linear_spec: bool = False + precompute_num_workers: int = 0 use_noise_augment: bool = False - add_blank: bool = False # dataset datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # optimizer From 38a0b3b6c7f9775dd7ecab071eb7ef5b7c4f3dc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:58:08 +0000 Subject: [PATCH 106/214] Update train_tts.py --- TTS/bin/train_tts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 16251fdd..ecc8aaf9 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -90,7 +90,6 @@ def main(): model=model, train_samples=train_samples, eval_samples=eval_samples, - training_assets={"audio_processor": ap}, parse_command_line_args=False, ) trainer.fit() From 6d9879bf66817dfe7e02dd24e0f48cf63659a793 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:58:41 +0000 Subject: [PATCH 107/214] Update ljspeech recipes --- recipes/ljspeech/align_tts/train_aligntts.py | 37 ++++++++++------ .../ljspeech/fast_pitch/train_fast_pitch.py | 30 +++++++------ .../ljspeech/fast_speech/train_fast_speech.py | 30 +++++++------ recipes/ljspeech/glow_tts/train_glowtts.py | 3 +- .../speedy_speech/train_speedy_speech.py | 43 +++++++++++-------- .../tacotron2-DCA/train_tacotron_dca.py | 39 +++++++++++------ .../tacotron2-DDC/train_tacotron_ddc.py | 25 +++++++++-- recipes/ljspeech/vits_tts/train_vits.py | 7 +-- recipes/vctk/vits/train_vits.py | 21 ++++++--- 9 files changed, 155 insertions(+), 80 deletions(-) diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py index 68b67d66..d0187aa8 100644 --- a/recipes/ljspeech/align_tts/train_aligntts.py +++ b/recipes/ljspeech/align_tts/train_aligntts.py @@ -1,9 +1,11 @@ import os from TTS.trainer import Trainer, TrainingArgs -from TTS.tts.configs.align_tts_config import AlignTTSConfig, BaseDatasetConfig +from TTS.tts.configs.align_tts_config import AlignTTSConfig +from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.align_tts import AlignTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -31,23 +33,32 @@ config = AlignTTSConfig( datasets=[dataset_config], ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. 
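+# As a rough sketch of what that means in code (variable names here are illustrative):
+#   token_ids = tokenizer.text_to_ids("hello world")  # text -> list of token IDs
+#   text_back = tokenizer.ids_to_text(token_ids)      # inverse mapping, handy for debugging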
+# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init model -model = AlignTTS(config) +model = AlignTTS(config, ap, tokenizer) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 🚀 trainer.fit() diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 0a4a965b..3a772251 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -5,6 +5,7 @@ from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.manage import ModelManager @@ -46,9 +47,9 @@ config = FastPitchConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=4, print_step=50, print_eval=False, mixed_precision=False, @@ -67,23 +68,28 @@ if not config.model_args.use_aligner: f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" ) -# init audio processor -ap = AudioProcessor(**config.audio) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. 
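+# A formatter is just a callable returning such samples; a minimal sketch
+# (hypothetical "wav|text" manifest, shown only for illustration) could be:
+#   def my_formatter(root_path, meta_file, **kwargs):
+#       items = []
+#       with open(os.path.join(root_path, meta_file), encoding="utf-8") as f:
+#           for line in f:
+#               wav_name, text = line.strip().split("|")
+#               items.append([text, os.path.join(root_path, wav_name), "my_speaker"])
+#       return items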
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init the model -model = ForwardTTS(config) +model = ForwardTTS(config, ap, tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) trainer.fit() diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py index a71da94b..f9f1bc06 100644 --- a/recipes/ljspeech/fast_speech/train_fast_speech.py +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -5,6 +5,7 @@ from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.manage import ModelManager @@ -45,9 +46,9 @@ config = FastSpeechConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=8, print_step=50, print_eval=False, mixed_precision=False, @@ -66,23 +67,28 @@ if not config.model_args.use_aligner: f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" ) -# init audio processor -ap = AudioProcessor(**config.audio) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init the model -model = ForwardTTS(config) +model = ForwardTTS(config, ap, tokenizer) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) trainer.fit() diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index 4762a77a..dd450a57 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -52,7 +52,8 @@ ap = AudioProcessor.init_from_config(config) # INITIALIZE THE TOKENIZER # Tokenizer is used to convert text to sequences of token IDs. 
-tokenizer = TTSTokenizer.init_from_config(config) +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) # LOAD DATA SAMPLES # Each sample is a list of ```[text, audio_file_path, speaker_name]``` diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 6b9683af..468e8a5f 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -5,6 +5,7 @@ from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -38,9 +39,9 @@ config = SpeedySpeechConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=4, print_step=50, print_eval=False, mixed_precision=False, @@ -50,14 +51,22 @@ config = SpeedySpeechConfig( datasets=[dataset_config], ) -# # compute alignments -# if not config.model_args.use_aligner: -# manager = ModelManager() -# model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") -# # TODO: make compute_attention python callable -# os.system( -# f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" -# ) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. +train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init audio processor ap = AudioProcessor(**config.audio.to_dict()) @@ -66,16 +75,14 @@ ap = AudioProcessor(**config.audio.to_dict()) train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init model -model = ForwardTTS(config) +model = ForwardTTS(config, ap, tokenizer) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py index 0a285c3b..a7f037e6 100644 --- a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py +++ b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py @@ -6,6 +6,7 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # from TTS.tts.datasets.tokenizer import Tokenizer @@ -60,23 +61,35 @@ config = Tacotron2Config( # This is the config that is saved for the future use datasets=[dataset_config], ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) -# init model -model = Tacotron2(config) +# INITIALIZE THE MODEL +# Models take a config object and a speaker manager as input +# Config defines the details of the model like the number of layers, the size of the embedding, etc. +# Speaker manager is used by multi-speaker models. +model = Tacotron2(config, ap, tokenizer) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py index b452094a..285c416c 100644 --- a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py +++ b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py @@ -6,6 +6,7 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # from TTS.tts.datasets.tokenizer import Tokenizer @@ -46,6 +47,7 @@ config = Tacotron2Config( # This is the config that is saved for the future use use_phonemes=True, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=8, print_step=25, print_eval=True, mixed_precision=False, @@ -56,11 +58,28 @@ config = Tacotron2Config( # This is the config that is saved for the future use # init audio processor ap = AudioProcessor(**config.audio.to_dict()) -# load training samples +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) -# init model -model = Tacotron2(config) +# INITIALIZE THE MODEL +# Models take a config object and a speaker manager as input +# Config defines the details of the model like the number of layers, the size of the embedding, etc. +# Speaker manager is used by multi-speaker models. +model = Tacotron2(config, ap, tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer( diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index 0588e9d9..79c0db2e 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -33,7 +33,7 @@ audio_config = BaseAudioConfig( config = VitsConfig( audio=audio_config, run_name="vits_ljspeech", - batch_size=48, + batch_size=16, eval_batch_size=16, batch_group_size=5, num_loader_workers=0, @@ -48,7 +48,7 @@ config = VitsConfig( compute_input_seq_cache=True, print_step=25, print_eval=True, - mixed_precision=True, + mixed_precision=False, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], @@ -61,7 +61,8 @@ ap = AudioProcessor.init_from_config(config) # INITIALIZE THE TOKENIZER # Tokenizer is used to convert text to sequences of token IDs. -tokenizer = TTSTokenizer.init_from_config(config) +# config is updated with the default characters if not defined in the config. 
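+# Note that `init_from_config` returns a (tokenizer, config) tuple here, so the
+# updated config must be kept and passed on to the model below.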
+tokenizer, config = TTSTokenizer.init_from_config(config) # LOAD DATA SAMPLES # Each sample is a list of ```[text, audio_file_path, speaker_name]``` diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py index 7eb741c4..2906557d 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -7,6 +7,7 @@ from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.vits import Vits, VitsArgs from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -63,10 +64,21 @@ config = VitsConfig( datasets=[dataset_config], ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# config is updated with the default characters if not defined in the config. +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -76,7 +88,7 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.model_args.num_speakers = speaker_manager.num_speakers # init model -model = Vits(config, speaker_manager) +model = Vits(config, ap, tokenizer, speaker_manager) # init the trainer and 🚀 trainer = Trainer( @@ -86,6 +98,5 @@ trainer = Trainer( model=model, train_samples=train_samples, eval_samples=eval_samples, - training_assets={"audio_processor": ap}, ) trainer.fit() From b341951b7801f846608dec269637cbb2e55e90e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:58:55 +0000 Subject: [PATCH 108/214] Update loader tests --- tests/data_tests/test_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 712e59e3..3ecd42e1 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -116,8 +116,8 @@ class TestTTSDataset(unittest.TestCase): if self.ap.symmetric_norm: self.assertLessEqual(mel_input.max(), self.ap.max_norm) self.assertGreaterEqual( - mel_input.min(), -self.ap.max_norm - ) # pylint: disable=invalid-unary-operand-type + mel_input.min(), -self.ap.max_norm # pylint: disable=invalid-unary-operand-type + ) self.assertLess(mel_input.min(), 0) else: self.assertLessEqual(mel_input.max(), self.ap.max_norm) From 0a47a7eac00eaaf3769b5b4275e68bfc9f41c4a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:59:11 +0000 Subject: [PATCH 109/214] Update tests --- tests/inference_tests/test_synthesize.py | 12 ++-- tests/text_tests/test_characters.py | 4 +- tests/text_tests/test_phonemizer.py | 85 +++++++++++++++++++++--- tests/text_tests/test_tokenizer.py | 14 ++-- tests/tts_tests/test_glow_tts_train.py | 1 - 
tests/tts_tests/test_vits_train.py | 1 - 6 files changed, 90 insertions(+), 27 deletions(-) diff --git a/tests/inference_tests/test_synthesize.py b/tests/inference_tests/test_synthesize.py index 635506ab..42b77172 100644 --- a/tests/inference_tests/test_synthesize.py +++ b/tests/inference_tests/test_synthesize.py @@ -19,9 +19,9 @@ def test_synthesize(): f'--text "This is an example." --out_path "{output_path}"' ) - # multi-speaker model - run_cli("tts --model_name tts_models/en/vctk/sc-glow-tts --list_speaker_idxs") - run_cli( - f'tts --model_name tts_models/en/vctk/sc-glow-tts --speaker_idx "p304" ' - f'--text "This is an example." --out_path "{output_path}"' - ) + # multi-speaker SC-Glow model + # run_cli("tts --model_name tts_models/en/vctk/sc-glow-tts --list_speaker_idxs") + # run_cli( + # f'tts --model_name tts_models/en/vctk/sc-glow-tts --speaker_idx "p304" ' + # f'--text "This is an example." --out_path "{output_path}"' + # ) diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py index ed84b5b4..3f4086d5 100644 --- a/tests/text_tests/test_characters.py +++ b/tests/text_tests/test_characters.py @@ -2,6 +2,8 @@ import unittest from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, create_graphemes, create_phonemes +# pylint: disable=protected-access + def test_make_symbols(): _ = create_phonemes() @@ -12,7 +14,7 @@ class BaseCharacterTest(unittest.TestCase): def setUp(self): self.characters_empty = BaseCharacters("", "", pad="", eos="", bos="", blank="", is_unique=True, is_sorted=True) - def test_default_character_sets(self): + def test_default_character_sets(self): # pylint: disable=no-self-use """Test initiation of default character sets""" _ = IPAPhonemes() _ = Graphemes() diff --git a/tests/text_tests/test_phonemizer.py b/tests/text_tests/test_phonemizer.py index aa7a5499..512cc195 100644 --- a/tests/text_tests/test_phonemizer.py +++ b/tests/text_tests/test_phonemizer.py @@ -1,20 +1,38 @@ import unittest -from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, create_graphemes, create_phonemes from TTS.tts.utils.text.phonemizers import ESpeak, Gruut, JA_JP_Phonemizer, ZH_CN_Phonemizer -from TTS.tts.utils.text.tokenizer import TTSTokenizer -EXAMPLE_TEXT = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" 
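+# The long example sentence is split into shorter, index-aligned chunks so that
+# each chunk's expected phonemization can be asserted independently.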
+EXAMPLE_TEXTs = [ + "Recent research at Harvard has shown meditating", + "for as little as 8 weeks can actually increase, the grey matter", + "in the parts of the brain responsible", + "for emotional regulation and learning!", +] + + +EXPECTED_ESPEAK_PHONEMES = [ + "ɹ|ˈiː|s|ə|n|t ɹ|ɪ|s|ˈɜː|tʃ æ|t h|ˈɑːɹ|v|ɚ|d h|ɐ|z ʃ|ˈoʊ|n m|ˈɛ|d|ɪ|t|ˌeɪ|ɾ|ɪ|ŋ", + "f|ɔː|ɹ æ|z l|ˈɪ|ɾ|əl æ|z ˈeɪ|t w|ˈiː|k|s k|æ|n ˈæ|k|tʃ|uː|əl|i| ˈɪ|n|k|ɹ|iː|s, ð|ə ɡ|ɹ|ˈeɪ m|ˈæ|ɾ|ɚ", + "ɪ|n|ð|ə p|ˈɑːɹ|t|s ʌ|v|ð|ə b|ɹ|ˈeɪ|n ɹ|ɪ|s|p|ˈɑː|n|s|ə|b|əl", + "f|ɔː|ɹ ɪ|m|ˈoʊ|ʃ|ə|n|əl ɹ|ˌɛ|ɡ|j|uː|l|ˈeɪ|ʃ|ə|n|| æ|n|d l|ˈɜː|n|ɪ|ŋ!", +] + + +EXPECTED_ESPEAKNG_PHONEMES = [ + "ɹ|ˈiː|s|ə|n|t ɹ|ᵻ|s|ˈɜː|tʃ æ|t h|ˈɑːɹ|v|ɚ|d h|ɐ|z ʃ|ˈoʊ|n m|ˈɛ|d|ᵻ|t|ˌeɪ|ɾ|ɪ|ŋ", + "f|ɔː|ɹ æ|z l|ˈɪ|ɾ|əl æ|z ˈeɪ|t w|ˈiː|k|s k|æ|n ˈæ|k|tʃ|uː|əl|i| ˈɪ|ŋ|k|ɹ|iː|s, ð|ə ɡ|ɹ|ˈeɪ m|ˈæ|ɾ|ɚ", + "ɪ|n|ð|ə p|ˈɑːɹ|t|s ʌ|v|ð|ə b|ɹ|ˈeɪ|n ɹ|ᵻ|s|p|ˈɑː|n|s|ᵻ|b|əl", + "f|ɔː|ɹ ɪ|m|ˈoʊ|ʃ|ə|n|əl ɹ|ˌɛ|ɡ|j|ʊ|l|ˈeɪ|ʃ|ə|n|| æ|n|d l|ˈɜː|n|ɪ|ŋ!", +] class TestEspeakPhonemizer(unittest.TestCase): def setUp(self): - self.phonemizer = ESpeak(language="en-us") - self.EXPECTED_PHONEMES = "ɹ|ˈiː|s|ə|n|t ɹ|ɪ|s|ˈɜː|tʃ æ|t h|ˈɑːɹ|v|ɚ|d h|ɐ|z ʃ|ˈoʊ|n m|ˈɛ|d|ᵻ|t|ˌeɪ|ɾ|ɪ|ŋ f|ɔː|ɹ æ|z l|ˈɪ|ɾ|əl æ|z ˈeɪ|t w|ˈiː|k|s k|æ|n ˈæ|k|tʃ|uː|əl|i| ˈɪ|n|k|ɹ|iː|s, ð|ə ɡ|ɹ|ˈeɪ m|ˈæ|ɾ|ɚ|ɹ ɪ|n|ð|ə p|ˈɑːɹ|t|s ʌ|v|ð|ə b|ɹ|ˈeɪ|n ɹ|ɪ|s|p|ˈɑː|n|s|ə|b|əl f|ɔː|ɹ ɪ|m|ˈoʊ|ʃ|ə|n|əl ɹ|ˌɛ|ɡ|j|uː|l|ˈeɪ|ʃ|ə|n|| æ|n|d l|ˈɜː|n|ɪ|ŋ!" + self.phonemizer = ESpeak(language="en-us", backend="espeak") - def test_phonemize(self): - output = self.phonemizer.phonemize(EXAMPLE_TEXT, separator="|") - self.assertEqual(output, self.EXPECTED_PHONEMES) + for text, ph in zip(EXAMPLE_TEXTs, EXPECTED_ESPEAK_PHONEMES): + phonemes = self.phonemizer.phonemize(text) + self.assertEqual(phonemes, ph) # multiple punctuations text = "Be a voice, not an! echo?" @@ -48,14 +66,59 @@ class TestEspeakPhonemizer(unittest.TestCase): self.assertTrue(self.phonemizer.is_available()) +class TestEspeakNgPhonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = ESpeak(language="en-us", backend="espeak-ng") + + for text, ph in zip(EXAMPLE_TEXTs, EXPECTED_ESPEAKNG_PHONEMES): + phonemes = self.phonemizer.phonemize(text) + self.assertEqual(phonemes, ph) + + # multiple punctuations + text = "Be a voice, not an! echo?" + gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ?" + output = self.phonemizer.phonemize(text, separator="|") + output = output.replace("|", "") + self.assertEqual(output, gt) + + # not ending with punctuation + text = "Be a voice, not an! echo" + gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ" + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + # extra space after the sentence + text = "Be a voice, not an! echo. " + gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ." 
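+        # the phonemizer should strip the trailing space instead of emitting extra tokens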
+ output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "espeak") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), dict) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) + + class TestGruutPhonemizer(unittest.TestCase): def setUp(self): self.phonemizer = Gruut(language="en-us", use_espeak_phonemes=True, keep_stress=False) - self.EXPECTED_PHONEMES = "ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ| f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ| ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l| f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!" + self.EXPECTED_PHONEMES = ["ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ", + "f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ", + "ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l", + "f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!" + ] def test_phonemize(self): - output = self.phonemizer.phonemize(EXAMPLE_TEXT, separator="|") - self.assertEqual(output, self.EXPECTED_PHONEMES) + for text, ph in zip(EXAMPLE_TEXTs, self.EXPECTED_PHONEMES): + phonemes = self.phonemizer.phonemize(text, separator="|") + self.assertEqual(phonemes, ph) # multiple punctuations text = "Be a voice, not an! echo?" diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py index 4d3fb0ce..47174518 100644 --- a/tests/text_tests/test_tokenizer.py +++ b/tests/text_tests/test_tokenizer.py @@ -1,6 +1,5 @@ import unittest from dataclasses import dataclass -from os import sep from coqpit import Coqpit @@ -13,7 +12,7 @@ class TestTTSTokenizer(unittest.TestCase): def setUp(self): self.tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) - self.ph = ESpeak("tr") + self.ph = ESpeak("tr", backend="espeak") self.tokenizer_ph = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph) def test_encode_decode_graphemes(self): @@ -54,12 +53,12 @@ class TestTTSTokenizer(unittest.TestCase): def test_not_found_characters(self): self.ph = ESpeak("en-us") - self.tokenizer_local = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph) + tokenizer_local = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph) self.assertEqual(len(self.tokenizer.not_found_characters), 0) text = "Yolk of one egg beaten light" - ids = self.tokenizer_local.text_to_ids(text) - text_hat = self.tokenizer_local.ids_to_text(ids) - self.assertEqual(self.tokenizer_local.not_found_characters, ["̩"]) + ids = tokenizer_local.text_to_ids(text) + text_hat = tokenizer_local.ids_to_text(ids) + self.assertEqual(tokenizer_local.not_found_characters, ["̩"]) self.assertEqual(text_hat, "jˈoʊk ʌv wˈʌn ˈɛɡ bˈiːʔn lˈaɪt") def test_init_from_config(self): @@ -85,7 +84,8 @@ class TestTTSTokenizer(unittest.TestCase): text_cleaner: str = "phoneme_cleaners" characters = Characters() - tokenizer_ph = TTSTokenizer.init_from_config(TokenizerConfig()) + tokenizer_ph, _ = TTSTokenizer.init_from_config(TokenizerConfig()) + tokenizer_ph.phonemizer.backend = "espeak" text = "Bu bir Örnek." 
text_ph = "" + self.ph.phonemize(text, separator="") + "" ids = tokenizer_ph.text_to_ids(text) diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests/test_glow_tts_train.py index 5a5533b6..e5dc44ee 100644 --- a/tests/tts_tests/test_glow_tts_train.py +++ b/tests/tts_tests/test_glow_tts_train.py @@ -17,7 +17,6 @@ config = GlowTTSConfig( num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py index 54e655ff..ec9a5915 100644 --- a/tests/tts_tests/test_vits_train.py +++ b/tests/tts_tests/test_vits_train.py @@ -17,7 +17,6 @@ config = VitsConfig( num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, From 4d99fee3e26415ad199e80f0a471f6706aca85c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:59:28 +0000 Subject: [PATCH 110/214] Update spec extractor --- TTS/bin/extract_tts_spectrograms.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 38f576b7..2a2c0b71 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -13,6 +13,7 @@ from TTS.config import load_config from TTS.tts.datasets import TTSDataset, load_tts_samples from TTS.tts.models import setup_model from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters @@ -20,21 +21,20 @@ use_cuda = torch.cuda.is_available() def setup_loader(ap, r, verbose=False): + tokenizer, _ = TTSTokenizer.init_from_config(c) dataset = TTSDataset( - r, - c.text_cleaner, + outputs_per_step=r, compute_linear_spec=False, - meta_data=meta_data, + samples=meta_data, + tokenizer=tokenizer, ap=ap, - characters=c.characters if "characters" in c.keys() else None, - add_blank=c["add_blank"] if "add_blank" in c.keys() else False, batch_group_size=0, - min_seq_len=c.min_seq_len, - max_seq_len=c.max_seq_len, + min_text_len=c.min_text_len, + max_text_len=c.max_text_len, + min_audio_len=c.min_audio_len, + max_audio_len=c.max_audio_len, phoneme_cache_path=c.phoneme_cache_path, - use_phonemes=c.use_phonemes, - phoneme_language=c.phoneme_language, - enable_eos_bos=c.enable_eos_bos_chars, + precompute_num_workers=0, use_noise_augment=False, verbose=verbose, speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None, @@ -44,7 +44,7 @@ def setup_loader(ap, r, verbose=False): if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. 
dataset.compute_input_seq(c.num_loader_workers) - dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False)) + dataset.preprocess_samples() loader = DataLoader( dataset, @@ -75,8 +75,8 @@ def set_filename(wav_path, out_path): def format_data(data): # setup input data - text_input = data["text"] - text_lengths = data["text_lengths"] + text_input = data["token_id"] + text_lengths = data["token_id_lengths"] mel_input = data["mel"] mel_lengths = data["mel_lengths"] item_idx = data["item_idxs"] From 17afd7a07cd30711b6c9c543710e449a96d43eb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 13:01:53 +0000 Subject: [PATCH 111/214] Update ljspeech download --- recipes/ljspeech/download_ljspeech.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/ljspeech/download_ljspeech.sh b/recipes/ljspeech/download_ljspeech.sh index 14ef058d..9468988a 100644 --- a/recipes/ljspeech/download_ljspeech.sh +++ b/recipes/ljspeech/download_ljspeech.sh @@ -10,5 +10,5 @@ tar -xjf LJSpeech-1.1.tar.bz2 shuf LJSpeech-1.1/metadata.csv > LJSpeech-1.1/metadata_shuf.csv head -n 12000 LJSpeech-1.1/metadata_shuf.csv > LJSpeech-1.1/metadata_train.csv tail -n 1100 LJSpeech-1.1/metadata_shuf.csv > LJSpeech-1.1/metadata_val.csv -mv LJSpeech-1.1 $RUN_DIR/ +mv LJSpeech-1.1 $RUN_DIR/recipes/ljspeech/ rm LJSpeech-1.1.tar.bz2 \ No newline at end of file From 420b92d5aef1d664490485e54f41dff08e48e722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 13:02:02 +0000 Subject: [PATCH 112/214] Update pylintrc --- .pylintrc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 6e9f953e..d5f9c490 100644 --- a/.pylintrc +++ b/.pylintrc @@ -168,7 +168,8 @@ disable=missing-docstring, exception-escape, comprehension-escape, duplicate-code, - not-callable + not-callable, + import-outside-toplevel # Enable the message, report, category or checker with the given id(s). 
You can # either give multiple identifier separated by comma (,) or put this option From 28d98da422b149fd61cf2ba6e37bd89d2cdd52c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 8 Dec 2021 14:45:32 +0000 Subject: [PATCH 113/214] Update VCTK formatter --- TTS/tts/datasets/formatters.py | 37 +++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 5cbc93db..1375757a 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -289,8 +289,10 @@ def brspeech(root_path, meta_file, ignored_speakers=None): return items -def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): - """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" +def vctk(root_path, meta_files=None, wavs_path="wav22", mic="mic2", ignored_speakers=None): + """https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip""" + file_ext = 'flac' + test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for meta_file in meta_files: @@ -302,26 +304,33 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): continue with open(meta_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] - wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") - items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id}) - + # p280 has no mic2 recordings + if speaker_id == "p280": + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}") + else: + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") + if os.path.exists(wav_file): + items.append([text, wav_file, "VCTK_" + speaker_id]) + else: + print(f" [!] 
wav file doesn't exist - {wav_file}")
     return items


-def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):  # pylint: disable=unused-argument
+def vctk_old(root_path, meta_files=None, wavs_path="wav48"):
     """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
+    test_speakers = meta_files
     items = []
-    txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
-    for text_file in txt_files:
-        _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
+    meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
+    for meta_file in meta_files:
+        _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
         file_id = txt_file.split(".")[0]
-        # ignore speakers
-        if isinstance(ignored_speakers, list):
-            if speaker_id in ignored_speakers:
+        if isinstance(test_speakers, list):  # if a list is given, these speaker ids are ignored
+            if speaker_id in test_speakers:
                 continue
+        with open(meta_file, "r", encoding="utf-8") as file_text:
+            text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
-        items.append([None, wav_file, "VCTK_" + speaker_id])
-
+        items.append([text, wav_file, "VCTK_old_" + speaker_id])
     return items

From 730f7c0df4f21d263dcee5e0e5098c373833e91d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 8 Dec 2021 14:45:57 +0000
Subject: [PATCH 114/214] Add file_ext args to resample.py

---
 TTS/bin/resample.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/TTS/bin/resample.py b/TTS/bin/resample.py
index 3c5ef29c..c9f1166a 100644
--- a/TTS/bin/resample.py
+++ b/TTS/bin/resample.py
@@ -26,6 +26,7 @@ if __name__ == "__main__":
             --input_dir /root/LJSpeech-1.1/
             --output_sr 22050
             --output_dir /root/resampled_LJSpeech-1.1/
+            --file_ext wav
             --n_jobs 24
         """,
         formatter_class=RawTextHelpFormatter,
     )
@@ -55,6 +56,14 @@
         help="Path of the destination folder. 
If not defined, the operation is done in place", ) + parser.add_argument( + "--file_ext", + type=str, + default="wav", + required=False, + help="Extension of the audio files to resample", + ) + parser.add_argument( "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores" ) @@ -67,7 +76,7 @@ if __name__ == "__main__": args.input_dir = args.output_dir print("Resampling the audio files...") - audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True) + audio_files = glob.glob(os.path.join(args.input_dir, f"**/*.{args.file_ext}"), recursive=True) print(f"Found {len(audio_files)} files...") audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr])) with Pool(processes=args.n_jobs) as p: From df0d58bf09c7189bebbd2e60498e9a3c300981f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 8 Dec 2021 15:15:56 +0000 Subject: [PATCH 115/214] Update VCTK recipes --- TTS/tts/datasets/formatters.py | 2 +- .../speedy_speech/train_speedy_speech.py | 6 --- recipes/vctk/fast_pitch/train_fast_pitch.py | 43 +++++++++++------ recipes/vctk/fast_speech/train_fast_speech.py | 48 +++++++++++-------- recipes/vctk/glow_tts/train_glow_tts.py | 41 +++++++++++----- .../vctk/speedy_speech/train_speedy_speech.py | 44 ++++++++++------- .../vctk/tacotron-DDC/train_tacotron-DDC.py | 42 ++++++++++------ .../vctk/tacotron2-DDC/train_tacotron2-ddc.py | 41 ++++++++++------ recipes/vctk/tacotron2/train_tacotron2.py | 41 ++++++++++------ 9 files changed, 192 insertions(+), 116 deletions(-) diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 1375757a..546c3cc3 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -289,7 +289,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None): return items -def vctk(root_path, meta_files=None, wavs_path="wav22", mic="mic2", ignored_speakers=None): +def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None): """https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip""" file_ext = 'flac' test_speakers = meta_files diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 468e8a5f..2f8896c5 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -68,12 +68,6 @@ tokenizer, config = TTSTokenizer.init_from_config(config) # Check `TTS.tts.datasets.load_tts_samples` for more details. 
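Since the VCTK formatter above now emits `[text, audio_file, speaker_name]` triplets from `wav48_silence_trimmed`, here is a hedged sketch of what one loaded sample looks like; the prompt text and exact path are invented for illustration:

```python
# one VCTK sample in the [text, audio_file, speaker_name] layout used by the formatter
sample = [
    "Please call Stella.",  # invented prompt text
    "recipes/vctk/VCTK/wav48_silence_trimmed/p225/p225_001_mic2.flac",  # assumed layout
    "VCTK_p225",
]
text, audio_file, speaker_name = sample
```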
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) - -# load training samples -train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) - # init model model = ForwardTTS(config, ap, tokenizer) diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py index f40587e0..f7a2ef06 100644 --- a/recipes/vctk/fast_pitch/train_fast_pitch.py +++ b/recipes/vctk/fast_pitch/train_fast_pitch.py @@ -6,6 +6,7 @@ from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -32,6 +33,7 @@ config = FastPitchConfig( num_loader_workers=8, num_eval_loader_workers=4, compute_input_seq_cache=True, + precompute_num_workers=4, compute_f0=True, f0_cache_path=os.path.join(output_path, "f0_cache"), run_eval=True, @@ -39,23 +41,35 @@ config = FastPitchConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, - max_seq_len=500000, + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=500000, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, ) -# init audio processor -ap = AudioProcessor(**config.audio) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -65,16 +79,15 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.model_args.num_speakers = speaker_manager.num_speakers # init model -model = ForwardTTS(config, speaker_manager) +model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() + diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py index b2988809..853bbb54 100644 --- a/recipes/vctk/fast_speech/train_fast_speech.py +++ b/recipes/vctk/fast_speech/train_fast_speech.py @@ -6,6 +6,7 @@ from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -25,37 +26,48 @@ audio_config = BaseAudioConfig( ) config = FastSpeechConfig( - run_name="fast_pitch_ljspeech", + run_name="fast_speech_vctk", audio=audio_config, batch_size=32, eval_batch_size=16, num_loader_workers=8, num_eval_loader_workers=4, compute_input_seq_cache=True, - compute_f0=True, - f0_cache_path=os.path.join(output_path, "f0_cache"), + precompute_num_workers=4, run_eval=True, test_delay_epochs=-1, epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, - max_seq_len=500000, + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=500000, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, ) -# init audio processor -ap = AudioProcessor(**config.audio) +## INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -65,16 +77,14 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.model_args.num_speakers = speaker_manager.num_speakers # init model -model = ForwardTTS(config, speaker_manager) +model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) -trainer.fit() + +# AND... 3,2,1... 
🚀 +trainer.fit() \ No newline at end of file diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py index 8c9f5388..30050ef5 100644 --- a/recipes/vctk/glow_tts/train_glow_tts.py +++ b/recipes/vctk/glow_tts/train_glow_tts.py @@ -7,6 +7,7 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.glow_tts import GlowTTS from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # set experiment paths @@ -32,6 +33,7 @@ config = GlowTTSConfig( eval_batch_size=16, num_loader_workers=4, num_eval_loader_workers=4, + precompute_num_workers=4, run_eval=True, test_delay_epochs=-1, epochs=1000, @@ -45,12 +47,27 @@ config = GlowTTSConfig( output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=500000, ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -60,16 +77,14 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.num_speakers = speaker_manager.num_speakers # init model -model = GlowTTS(config, speaker_manager) +model = GlowTTS(config, ap, tokenizer, speaker_manager=speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) -trainer.fit() + +# AND... 3,2,1... 
🚀 +trainer.fit() \ No newline at end of file diff --git a/recipes/vctk/speedy_speech/train_speedy_speech.py b/recipes/vctk/speedy_speech/train_speedy_speech.py index 81f78d26..85e347fc 100644 --- a/recipes/vctk/speedy_speech/train_speedy_speech.py +++ b/recipes/vctk/speedy_speech/train_speedy_speech.py @@ -6,6 +6,7 @@ from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -32,30 +33,41 @@ config = SpeedySpeechConfig( num_loader_workers=8, num_eval_loader_workers=4, compute_input_seq_cache=True, - compute_f0=True, - f0_cache_path=os.path.join(output_path, "f0_cache"), + precompute_num_workers=4, run_eval=True, test_delay_epochs=-1, epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, - max_seq_len=500000, + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=500000, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, ) -# init audio processor -ap = AudioProcessor(**config.audio) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -65,16 +77,14 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.model_args.num_speakers = speaker_manager.num_speakers # init model -model = ForwardTTS(config, speaker_manager) +model = ForwardTTS(config, ap, tokenizer, speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py index b0030f17..7960b34b 100644 --- a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py +++ b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py @@ -7,6 +7,7 @@ from TTS.tts.configs.tacotron_config import TacotronConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron import Tacotron from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -32,6 +33,7 @@ config = TacotronConfig( # This is the config that is saved for the future use eval_batch_size=16, num_loader_workers=4, num_eval_loader_workers=4, + precompute_num_workers=4, run_eval=True, test_delay_epochs=-1, r=6, @@ -45,18 +47,30 @@ config = TacotronConfig( # This is the config that is saved for the future use print_step=25, print_eval=False, mixed_precision=True, - sort_by_audio_len=True, - min_seq_len=0, - max_seq_len=44000 * 10, # 44k is the original sampling rate before resampling, corresponds to 10 seconds of audio + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=44000 * 10, # 44k is the original sampling rate before resampling, corresponds to 10 seconds of audio output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, # set this to enable multi-sepeaker training ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +## INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -65,16 +79,14 @@ speaker_manager = SpeakerManager() speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) # init model -model = Tacotron(config, speaker_manager) +model = Tacotron(config, ap, tokenizer, speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py index 63efb784..bc7951b5 100644 --- a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py +++ b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py @@ -7,6 +7,7 @@ from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -44,9 +45,10 @@ config = Tacotron2Config( # This is the config that is saved for the future use print_step=150, print_eval=False, mixed_precision=True, - sort_by_audio_len=True, - min_seq_len=14800, - max_seq_len=22050 * 10, # 44k is the original sampling rate before resampling, corresponds to 10 seconds of audio + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=44000 * 10, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, # set this to enable multi-sepeaker training @@ -60,10 +62,21 @@ config = Tacotron2Config( # This is the config that is saved for the future use lr=3e-5, ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -72,16 +85,14 @@ speaker_manager = SpeakerManager() speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) # init model -model = Tacotron2(config, speaker_manager) +model = Tacotron2(config, ap, tokenizer, speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/tacotron2/train_tacotron2.py b/recipes/vctk/tacotron2/train_tacotron2.py index 346d650b..82dedade 100644 --- a/recipes/vctk/tacotron2/train_tacotron2.py +++ b/recipes/vctk/tacotron2/train_tacotron2.py @@ -7,6 +7,7 @@ from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -44,9 +45,10 @@ config = Tacotron2Config( # This is the config that is saved for the future use print_step=150, print_eval=False, mixed_precision=True, - sort_by_audio_len=True, - min_seq_len=14800, - max_seq_len=22050 * 10, # 44k is the original sampling rate before resampling, corresponds to 10 seconds of audio + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=44000 * 10, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, # set this to enable multi-sepeaker training @@ -60,10 +62,21 @@ config = Tacotron2Config( # This is the config that is saved for the future use lr=3e-5, ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +## INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -72,16 +85,14 @@ speaker_manager = SpeakerManager() speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) # init model -model = Tacotron2(config, speaker_manager) +model = Tacotron2(config, ap, tokenizer, speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit()

From c0746f23dfb69a7b306a8b386ec08ecb05c6fea3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 8 Dec 2021 15:16:16 +0000
Subject: [PATCH 116/214] Fix `too many open files`

---
 TTS/tts/datasets/dataset.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 210de803..dee719ef 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -11,6 +11,10 @@ from torch.utils.data import Dataset
 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.utils.audio import AudioProcessor

+# to prevent the `too many open files` error, as suggested here:
+# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
+torch.multiprocessing.set_sharing_strategy('file_system')
+

 def _parse_sample(item):
     language_name = None

From 29139172fa73b3013afd13b5d7bde68e5a3adaff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 8 Dec 2021 15:18:14 +0000
Subject: [PATCH 117/214] Update recipes README.md

---
 recipes/README.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/recipes/README.md b/recipes/README.md
index cf3f3de9..21a6727d 100644
--- a/recipes/README.md
+++ b/recipes/README.md
@@ -11,6 +11,12 @@
 $ sh ./recipes//download_.sh
 $ python recipes///train.py
 ```
+For some datasets you might need to resample the audio files first. For example, the VCTK dataset can be resampled to 22050Hz as follows.
+
+```console
+python TTS/bin/resample.py --input_dir recipes/vctk/VCTK/wav48_silence_trimmed --output_sr 22050 --output_dir recipes/vctk/VCTK/wav48_silence_trimmed --n_jobs 8 --file_ext flac
+```
+
 If you train a new model using TTS, feel free to share your training recipe to expand the list. You can also open a new discussion and share your progress with the 🐸 community. 
\ No newline at end of file From edec27738bf1022badb98e95eb077d837b205f49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 Jan 2022 15:32:31 +0000 Subject: [PATCH 118/214] Delete `use_espeak_phonemes` from tests --- tests/tts_tests/test_vits_d-vectors_train.py | 1 - tests/tts_tests/test_vits_multilingual_speaker_emb_train.py | 3 +-- tests/tts_tests/test_vits_speaker_emb_train.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/tts_tests/test_vits_d-vectors_train.py b/tests/tts_tests/test_vits_d-vectors_train.py index 213669f5..5fd9cbc1 100644 --- a/tests/tts_tests/test_vits_d-vectors_train.py +++ b/tests/tts_tests/test_vits_d-vectors_train.py @@ -16,7 +16,6 @@ config = VitsConfig( num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index 78023d26..afa60a1b 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -34,8 +34,7 @@ config = VitsConfig( num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, - phoneme_language="en", + phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, test_delay_epochs=-1, diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py index 8909e8db..1aecc596 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_train.py @@ -17,7 +17,6 @@ config = VitsConfig( num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, From 131bc0cfc0c9699439628fa1e517793291ad5eb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 Jan 2022 15:33:24 +0000 Subject: [PATCH 119/214] =?UTF-8?q?Fix=20synthesis.py=20=F0=9F=94=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/tts/utils/synthesis.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 65dcc1ad..e2d9c113 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -113,8 +113,6 @@ def synthesis( text, CONFIG, use_cuda, - ap, - tokenizer, speaker_id=None, style_wav=None, use_griffin_lim=False, @@ -166,10 +164,10 @@ def synthesis( if isinstance(style_wav, dict): style_mel = style_wav else: - style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) + style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) # convert text to sequence of token IDs text_inputs = np.asarray( - tokenizer.text_to_ids(text), + model.tokenizer.text_to_ids(text), dtype=np.int32, ) # pass tensors to backend From 5176ae9e53cae7c71a2f19f240bde8752d4b038e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 Jan 2022 15:38:08 +0000 Subject: [PATCH 120/214] Fixes small compat. 
issues --- TTS/tts/datasets/__init__.py | 4 +-- TTS/tts/datasets/dataset.py | 2 +- TTS/tts/datasets/formatters.py | 2 +- TTS/tts/models/base_tts.py | 8 +++-- TTS/tts/utils/languages.py | 9 ++++++ TTS/tts/utils/speakers.py | 30 ++++++++++++------- recipes/vctk/fast_pitch/train_fast_pitch.py | 1 - recipes/vctk/fast_speech/train_fast_speech.py | 2 +- recipes/vctk/glow_tts/train_glow_tts.py | 2 +- tests/text_tests/test_phonemizer.py | 9 +++--- 10 files changed, 44 insertions(+), 25 deletions(-) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index f0a6ea95..d80e92c9 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -111,8 +111,8 @@ def load_tts_samples( meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train # load attention masks for the duration predictor training - if d.meta_file_attn_mask: - meta_data = dict(load_attention_mask_meta_data(d["meta_file_attn_mask"])) + if dataset.meta_file_attn_mask: + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): attn_file = meta_data[ins["audio_file"]].strip() meta_data_train_all[idx].update({"alignment_file": attn_file}) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index dee719ef..a98afc95 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -13,7 +13,7 @@ from TTS.utils.audio import AudioProcessor # to prevent too many open files error as suggested here # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 -torch.multiprocessing.set_sharing_strategy('file_system') +torch.multiprocessing.set_sharing_strategy("file_system") def _parse_sample(item): diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 546c3cc3..5168dd06 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -291,7 +291,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None): def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None): """https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip""" - file_ext = 'flac' + file_ext = "flac" test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 59862322..9a6a56df 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -261,7 +261,7 @@ class BaseTTS(BaseModel): speaker_id_mapping = None d_vector_mapping = None - # setup custom symbols if needed + # setup multi-lingual attributes if hasattr(self, "language_manager"): language_id_mapping = ( self.language_manager.language_id_mapping if self.args.use_language_embedding else None @@ -290,6 +290,7 @@ class BaseTTS(BaseModel): speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=self.tokenizer, + language_id_mapping=language_id_mapping, ) # wait all the DDP process to be ready @@ -303,6 +304,7 @@ class BaseTTS(BaseModel): sampler = DistributedSampler(dataset) if num_gpus > 1 else None # Weighted samplers + # TODO: make this DDP amenable assert not ( num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False) ), "language_weighted_sampler is not supported with DistributedSampler" @@ -313,10 +315,10 @@ class BaseTTS(BaseModel): if sampler is None: if getattr(config, "use_language_weighted_sampler", False): print(" > 
Using Language weighted sampler") - sampler = get_language_weighted_sampler(dataset.items) + sampler = get_language_weighted_sampler(dataset.samples) elif getattr(config, "use_speaker_weighted_sampler", False): print(" > Using Language weighted sampler") - sampler = get_speaker_weighted_sampler(dataset.items) + sampler = get_speaker_weighted_sampler(dataset.samples) loader = DataLoader( dataset, diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index a4f41be5..78b535a0 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -98,6 +98,15 @@ class LanguageManager: """ self._save_json(file_path, self.language_id_mapping) + @staticmethod + def init_from_config(config: Coqpit) -> "LanguageManager": + """Initialize the language manager from a Coqpit config. + + Args: + config (Coqpit): Coqpit config. + """ + return LanguageManager(config=config) + def _set_file_path(path): """Find the language_ids.json under the given path or the above it. diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index ba48f27c..99d653e6 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -9,7 +9,7 @@ import torch from coqpit import Coqpit from torch.utils.data.sampler import WeightedRandomSampler -from TTS.config import load_config +from TTS.config import get_from_config_or_model_args_with_default, load_config from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.utils.audio import AudioProcessor @@ -331,19 +331,27 @@ class SpeakerManager: SpeakerEncoder: Speaker encoder object. """ speaker_manager = None - if hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding: + if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False): if samples: speaker_manager = SpeakerManager(data_items=samples) - if config.get("speaker_file", None): - speaker_manager = SpeakerManager(speaker_id_file_path=config.speaker_file) - if config.get("speakers_file", None): - speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file) + if get_from_config_or_model_args_with_default(config, "speaker_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None) + ) - if hasattr(config, "use_d_vector_file") and config.use_d_vector_file: - if config.get("speakers_file", None): - speaker_manager = SpeakerManager(d_vectors_file_path=config.speaker_file) - if config.get("d_vector_file", None): - speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) + if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False): + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "d_vector_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None) + ) return speaker_manager diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py index f7a2ef06..4d9cc10d 100644 --- a/recipes/vctk/fast_pitch/train_fast_pitch.py +++ 
b/recipes/vctk/fast_pitch/train_fast_pitch.py @@ -90,4 +90,3 @@ trainer = Trainer( # AND... 3,2,1... 🚀 trainer.fit() - diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py index 853bbb54..1dcab982 100644 --- a/recipes/vctk/fast_speech/train_fast_speech.py +++ b/recipes/vctk/fast_speech/train_fast_speech.py @@ -87,4 +87,4 @@ trainer = Trainer( ) # AND... 3,2,1... 🚀 -trainer.fit() \ No newline at end of file +trainer.fit() diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py index 30050ef5..e35e552d 100644 --- a/recipes/vctk/glow_tts/train_glow_tts.py +++ b/recipes/vctk/glow_tts/train_glow_tts.py @@ -87,4 +87,4 @@ trainer = Trainer( ) # AND... 3,2,1... 🚀 -trainer.fit() \ No newline at end of file +trainer.fit() diff --git a/tests/text_tests/test_phonemizer.py b/tests/text_tests/test_phonemizer.py index 512cc195..9b619f6e 100644 --- a/tests/text_tests/test_phonemizer.py +++ b/tests/text_tests/test_phonemizer.py @@ -109,10 +109,11 @@ class TestEspeakNgPhonemizer(unittest.TestCase): class TestGruutPhonemizer(unittest.TestCase): def setUp(self): self.phonemizer = Gruut(language="en-us", use_espeak_phonemes=True, keep_stress=False) - self.EXPECTED_PHONEMES = ["ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ", - "f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ", - "ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l", - "f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!" + self.EXPECTED_PHONEMES = [ + "ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ", + "f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ", + "ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l", + "f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!", ] def test_phonemize(self): From 001da8afc8285bb8a936d331d8f925a3b16f1641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 Jan 2022 15:38:29 +0000 Subject: [PATCH 121/214] Update Vits for the new model API --- TTS/tts/models/vits.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 30dc7ec4..b5551268 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -278,7 +278,12 @@ class Vits(BaseTTS): # pylint: disable=dangerous-default-value def __init__( - self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None, language_manager: LanguageManager = None + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, ): super().__init__(config, ap, tokenizer, speaker_manager) @@ -287,8 +292,6 @@ class Vits(BaseTTS): self.speaker_manager = speaker_manager self.language_manager = language_manager - self.args = args - self.init_multispeaker(config) self.init_multilingual(config) @@ -309,6 +312,7 @@ class Vits(BaseTTS): self.args.num_layers_text_encoder, self.args.kernel_size_text_encoder, self.args.dropout_p_text_encoder, + language_emb_dim=self.embedded_language_dim, ) self.posterior_encoder = PosteriorEncoder( @@ -884,7 +888,7 @@ class Vits(BaseTTS): return self._log(self.ap, batch, outputs, "eval") @torch.no_grad() - def test_run(self) -> 
Tuple[Dict, Dict]: + def test_run(self, assets) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. @@ -904,7 +908,7 @@ class Vits(BaseTTS): aux_inputs["text"], self.config, "cuda" in str(next(self.parameters()).device), - ap, + self.ap, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], @@ -1007,7 +1011,8 @@ class Vits(BaseTTS): ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) - return Vits(new_config, ap, tokenizer, speaker_manager) + language_manager = LanguageManager.init_from_config(config) + return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) class VitsCharacters(BaseCharacters): From 8e248913d60a6711eea624d192f4664b22da732b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 Jan 2022 15:38:57 +0000 Subject: [PATCH 122/214] Update train_tts for the new API --- TTS/bin/train_tts.py | 37 +------------------------------------ 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index ecc8aaf9..c36dc529 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -44,43 +44,8 @@ def main(): # load training samples train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True, eval_split_max_size=config.eval_split_max_size, eval_split_size=config.eval_split_size) - # setup audio processor - ap = AudioProcessor(**config.audio) - - # init speaker manager - if check_config_and_model_args(config, "use_speaker_embedding", True): - speaker_manager = SpeakerManager(data_items=train_samples + eval_samples) - if hasattr(config, "model_args"): - config.model_args.num_speakers = speaker_manager.num_speakers - else: - config.num_speakers = speaker_manager.num_speakers - elif check_config_and_model_args(config, "use_d_vector_file", True): - if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True): - speaker_manager = SpeakerManager( - d_vectors_file_path=config.model_args.d_vector_file, - encoder_model_path=config.model_args.speaker_encoder_model_path, - encoder_config_path=config.model_args.speaker_encoder_config_path, - use_cuda=torch.cuda.is_available(), - ) - else: - speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file")) - config.num_speakers = speaker_manager.num_speakers - if hasattr(config, "model_args"): - config.model_args.num_speakers = speaker_manager.num_speakers - else: - speaker_manager = None - - if check_config_and_model_args(config, "use_language_embedding", True): - language_manager = LanguageManager(config=config) - if hasattr(config, "model_args"): - config.model_args.num_languages = language_manager.num_languages - else: - config.num_languages = language_manager.num_languages - else: - language_manager = None - # init the model from config - model = setup_model(config, speaker_manager, language_manager) + model = setup_model(config, train_samples + eval_samples) # init the trainer and 🚀 trainer = Trainer( From 235f7d9b026b4e7a0bec09c5e36f81eb019f7420 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 12 Jan 2022 11:35:52 +0000 Subject: [PATCH 123/214] Extend glow_tts model tests --- TTS/tts/models/glow_tts.py | 63 +++++-- tests/tts_tests/test_glow_tts.py | 293 +++++++++++++++++++++++++++---- 2 files changed, 300 
insertions(+), 56 deletions(-)

diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 7a48b023..869adcad 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -40,11 +40,20 @@ class GlowTTS(BaseTTS):
     Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.

     Examples:
+        Init only model layers.
+
+        >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
+        >>> from TTS.tts.models.glow_tts import GlowTTS
+        >>> config = GlowTTSConfig(num_chars=2)
+        >>> model = GlowTTS(config)
+
+        Fully init a model ready for action. All the class attributes and class members
+        (e.g. Tokenizer, AudioProcessor, etc.) are initialized internally based on config values.
+
         >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
         >>> from TTS.tts.models.glow_tts import GlowTTS
         >>> config = GlowTTSConfig()
-        >>> model = GlowTTS(config)
-
+        >>> model = GlowTTS.init_from_config(config, verbose=False)
     """

     def __init__(
@@ -98,25 +107,23 @@ class GlowTTS(BaseTTS):

     def init_multispeaker(self, config: Coqpit):
         """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
-        vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
+        vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets
+        speaker embedding vector dimension to the d-vector dimension from the config.

         Args:
             config (Coqpit): Model configuration.
         """
         self.embedded_speaker_dim = 0
-        # init speaker manager
-        if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
-            raise ValueError(
-                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
-            )
         # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
         if self.speaker_manager is not None:
             self.num_speakers = self.speaker_manager.num_speakers
         # set ultimate speaker embedding size
-        if config.use_speaker_embedding or config.use_d_vector_file:
+        if config.use_d_vector_file:
             self.embedded_speaker_dim = (
                 config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
             )
+            if self.speaker_manager is not None:
+                assert config.d_vector_dim == self.speaker_manager.d_vector_dim, " [!] d-vector dimension mismatch between config and speaker manager."
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
             print(" > Init speaker_embedding layer.")
@@ -186,12 +193,33 @@ class GlowTTS(BaseTTS):
         self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
     ):  # pylint: disable=dangerous-default-value
         """
-        Shapes:
-            - x: :math:`[B, T]`
-            - x_lenghts::math:`B`
-            - y: :math:`[B, T, C]`
-            - y_lengths::math:`B`
-            - g: :math:`[B, C] or B`
+        Args:
+            x (torch.Tensor):
+                Input text sequence ids. :math:`[B, T_en]`
+
+            x_lengths (torch.Tensor):
+                Lengths of input text sequences. :math:`[B]`
+
+            y (torch.Tensor):
+                Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
+
+            y_lengths (torch.Tensor):
+                Lengths of target mel-spectrogram frames. :math:`[B]`
+
+            aux_input (Dict):
+                Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model.
+                :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model using a speaker-embedding
+                layer. 
:math:`B` + + Returns: + Dict: + - z: :math: `[B, T_de, C]` + - logdet: :math:`B` + - y_mean: :math:`[B, T_de, C]` + - y_log_scale: :math:`[B, T_de, C]` + - alignments: :math:`[B, T_en, T_de]` + - durations_log: :math:`[B, T_en, 1]` + - total_durations_log: :math:`[B, T_en, 1]` """ # [B, T, C] -> [B, C, T] y = y.transpose(1, 2) @@ -510,17 +538,18 @@ class GlowTTS(BaseTTS): self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps @staticmethod - def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None): + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config) + ap = AudioProcessor.init_from_config(config, verbose) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return GlowTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index 82d0ec3b..e97b793a 100644 --- a/tests/tts_tests/test_glow_tts.py +++ b/tests/tts_tests/test_glow_tts.py @@ -1,11 +1,13 @@ import copy import os import unittest +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.logging.tensorboard_logger import TensorboardLogger import torch from torch import optim -from tests import get_tests_input_path +from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.models.glow_tts import GlowTTS @@ -28,36 +30,211 @@ def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) -class GlowTTSTrainTest(unittest.TestCase): - @staticmethod - def test_train_step(): +class TestGlowTTS(unittest.TestCase): + def _create_inputs(self): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) mel_lengths = torch.randint(20, 30, (8,)).long().to(device) speaker_ids = torch.randint(0, 5, (8,)).long().to(device) + return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + def _check_parameter_changes(self, model, model_ref): + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( + count, param.shape, param, param_ref + ) + count += 1 + + def test_init_multispeaker(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config) + # speaker embedding with default speaker_embedding_dim + config.use_speaker_embedding = True + config.num_speakers = 5 + config.d_vector_dim = None + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, model.hidden_channels_enc) + # use external speaker embeddings with speaker_embedding_dim = 301 + config = GlowTTSConfig(num_chars=32) + config.use_d_vector_file = True + config.d_vector_dim = 301 + model = GlowTTS(config) + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, 301) + # use speaker embedddings by the provided speaker_manager + config = GlowTTSConfig(num_chars=32) + config.use_speaker_embedding = True + config.speakers_file = os.path.join(get_tests_data_path(), "ljspeech", "speakers.json") + speaker_manager = SpeakerManager.init_from_config(config) + model = GlowTTS(config) + model.speaker_manager = speaker_manager + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, model.hidden_channels_enc) + self.assertEqual(model.num_speakers, speaker_manager.num_speakers) + # use external speaker embeddings by the provided speaker_manager + config = GlowTTSConfig(num_chars=32) + config.use_d_vector_file = True + config.d_vector_dim = 256 + config.d_vector_file = os.path.join(get_tests_data_path(), "dummy_speakers.json") + speaker_manager = SpeakerManager.init_from_config(config) + model = GlowTTS(config) + model.speaker_manager = speaker_manager + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, speaker_manager.d_vector_dim) + self.assertEqual(model.num_speakers, speaker_manager.num_speakers) + + def test_unlock_act_norm_layers(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.unlock_act_norm_layers() + for f in model.decoder.flows: + if getattr(f, "set_ddi", False): + self.assertFalse(f.initialized) + + def test_lock_act_norm_layers(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.lock_act_norm_layers() + for f in model.decoder.flows: + if getattr(f, "set_ddi", False): + self.assertTrue(f.initialized) + + def test_forward(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + # create model + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.train() + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + # inference encoder and decoder with MAS + y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths) + self.assertEqual(y["z"].shape, mel_spec.shape) + self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["y_mean"].shape, mel_spec.shape) + self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) + self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + + def test_forward_with_d_vector(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + d_vector = torch.rand(8, 256).to(device) + # create model + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + model = 
GlowTTS.init_from_config(config, verbose=False).to(device) + model.train() + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + # inference encoder and decoder with MAS + y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector}) + self.assertEqual(y["z"].shape, mel_spec.shape) + self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["y_mean"].shape, mel_spec.shape) + self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) + self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + + def test_forward_with_speaker_id(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + speaker_ids = torch.randint(0, 24, (8,)).long().to(device) + # create model + config = GlowTTSConfig( + num_chars=32, + use_speaker_embedding=True, + num_speakers=24, + ) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.train() + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + # inference encoder and decoder with MAS + y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids}) + self.assertEqual(y["z"].shape, mel_spec.shape) + self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["y_mean"].shape, mel_spec.shape) + self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) + self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + + def _assert_inference_outputs(self, outputs, input_dummy, mel_spec): + output_shape = outputs["model_outputs"].shape + self.assertEqual(outputs["model_outputs"].shape[::2] , mel_spec.shape[::2]) + self.assertEqual(outputs["logdet"], None) + self.assertEqual(outputs["y_mean"].shape, output_shape) + self.assertEqual(outputs["y_log_scale"].shape, output_shape) + self.assertEqual(outputs["alignments"].shape, output_shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(outputs["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(outputs["total_durations_log"].shape, input_dummy.shape + (1,)) + + def test_inference(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.eval() + outputs = model.inference(input_dummy, {"x_lengths": input_lengths}) + self._assert_inference_outputs(outputs, input_dummy, mel_spec) + + def test_inference_with_d_vector(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + d_vector = torch.rand(8, 256).to(device) + config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json")) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.eval() + outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) + self._assert_inference_outputs(outputs, input_dummy, mel_spec) + + def test_inference_with_speaker_ids(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + speaker_ids = torch.randint(0, 24, (8,)).long().to(device) + # create 
model + config = GlowTTSConfig( + num_chars=32, + use_speaker_embedding=True, + num_speakers=24, + ) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) + self._assert_inference_outputs(outputs, input_dummy, mel_spec) + + def test_inference_with_MAS(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + # create model + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.eval() + # inference encoder and decoder with MAS + y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths) + y2 = model.decoder_inference(mel_spec, mel_lengths) + assert ( + y2["model_outputs"].shape == y["model_outputs"].shape + ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format( + y["model_outputs"].shape, y2["model_outputs"].shape + ) + + def test_train_step(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() criterion = GlowTTSLoss() - # model to train config = GlowTTSConfig(num_chars=32) model = GlowTTS(config).to(device) - # reference model to compare model weights model_ref = GlowTTS(config).to(device) - model.train() print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) - # pass the state to ref model model_ref.load_state_dict(copy.deepcopy(model.state_dict())) - count = 0 for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 - optimizer = optim.Adam(model.parameters(), lr=0.001) for _ in range(5): optimizer.zero_grad() @@ -75,40 +252,78 @@ class GlowTTSTrainTest(unittest.TestCase): loss = loss_dict["loss"] loss.backward() optimizer.step() - # check parameter changes - count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): - assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( - count, param.shape, param, param_ref - ) - count += 1 + self._check_parameter_changes(model, model_ref) - -class GlowTTSInferenceTest(unittest.TestCase): - @staticmethod - def test_inference(): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8,)).long().to(device) - input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) - speaker_ids = torch.randint(0, 5, (8,)).long().to(device) - - # create model + def test_train_eval_log(self): + input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs() + batch = {} + batch["text_input"] = input_dummy + batch["text_lengths"] = input_lengths + batch["mel_lengths"] = mel_lengths + batch["mel_input"] = mel_spec + batch["d_vectors"] = None + batch["speaker_ids"] = None config = GlowTTSConfig(num_chars=32) - model = GlowTTS(config).to(device) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.run_data_dep_init = False + model.train() + logger = TensorboardLogger(log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name = "glow_tts_test_train_log") + criterion = model.get_criterion() + outputs, _ = model.train_step(batch, criterion) + model.train_log(batch, outputs, logger, None, 1) + model.eval_log(batch, outputs, logger, None, 1) + logger.finish() + def test_test_run(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.run_data_dep_init = False model.eval() - print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + test_figures, test_audios = model.test_run(None) + self.assertTrue(test_figures is not None) + self.assertTrue(test_audios is not None) - # inference encoder and decoder with MAS - y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths) + def test_load_checkpoint(self): + chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + chkp = {} + chkp["model"] = model.state_dict() + torch.save(chkp, chkp_path) + model.load_checkpoint(config, chkp_path) + self.assertTrue(model.training) + model.load_checkpoint(config, chkp_path, eval=True) + self.assertFalse(model.training) - y2 = model.decoder_inference(mel_spec, mel_lengths) + def test_get_criterion(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + criterion = model.get_criterion() + self.assertTrue(criterion is not None) + + def test_init_from_config(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + + config = GlowTTSConfig(num_chars=32, num_speakers=2) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 2) + self.assertTrue(not hasattr(model, "emb_g")) + + config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 2) + self.assertTrue(hasattr(model, "emb_g")) + + config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json")) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + 
self.assertTrue(model.num_speakers == 10) + self.assertTrue(hasattr(model, "emb_g")) + + config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json")) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 1) + self.assertTrue(not hasattr(model, "emb_g")) + self.assertTrue(model.c_in_channels == config.d_vector_dim) - assert ( - y2["model_outputs"].shape == y["model_outputs"].shape - ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format( - y["model_outputs"].shape, y2["model_outputs"].shape - ) From 50e17097a7219e882f9a19dfa200e93b6c9f69a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 12 Jan 2022 11:36:09 +0000 Subject: [PATCH 124/214] Add verbose option to AudioProcessor --- TTS/utils/audio.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index bfa0e5e1..4d20f468 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -380,10 +380,10 @@ class AudioProcessor(object): self.symmetric_norm = None @staticmethod - def init_from_config(config: "Coqpit"): + def init_from_config(config: "Coqpit", verbose=True): if "audio" in config: - return AudioProcessor(**config.audio) - return AudioProcessor(**config) + return AudioProcessor(verbose=verbose, **config.audio) + return AudioProcessor(verbose=verbose, **config) ### setting up the parameters ### def _build_mel_basis( From 07b0a80d573bb4da794e1b7d778fe43dbd389d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 12 Jan 2022 11:36:25 +0000 Subject: [PATCH 125/214] Fix tokenizer init_from_config --- TTS/tts/utils/text/tokenizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 3f416bbb..f84a51ee 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -146,8 +146,9 @@ class TTSTokenizer: the config values. Defaults to None. 
""" # init cleaners + text_cleaner = None if isinstance(config.text_cleaner, (str, list)): - text_cleaner = getattr(cleaners, config.text_cleaner) + text_cleaner = getattr(config, "text_cleaner") # init characters if characters is None: From 7b49a4aa2bf8116534945b6475135d6444b0a5d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 12 Jan 2022 11:36:46 +0000 Subject: [PATCH 126/214] Fix glow_tts_config missing field --- TTS/tts/configs/glow_tts_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/TTS/tts/configs/glow_tts_config.py b/TTS/tts/configs/glow_tts_config.py index ce8eee6d..f42f3e5a 100644 --- a/TTS/tts/configs/glow_tts_config.py +++ b/TTS/tts/configs/glow_tts_config.py @@ -153,6 +153,7 @@ class GlowTTSConfig(BaseTTSConfig): # multi-speaker settings use_speaker_embedding: bool = False + speakers_file: str = None use_d_vector_file: bool = False d_vector_file: str = False From d0eb3e4ef2ab59d99f4a41c9465a523ae0096012 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 12 Jan 2022 11:37:02 +0000 Subject: [PATCH 127/214] Add get_tests_data_path --- tests/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/__init__.py b/tests/__init__.py index 0a0c3379..8906c8c7 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -26,6 +26,11 @@ def get_tests_input_path(): return os.path.join(get_tests_path(), "inputs") +def get_tests_data_path(): + """Returns the path to the test data directory.""" + return os.path.join(get_tests_path(), "data") + + def get_tests_output_path(): """Returns the path to the directory for test outputs.""" return os.path.join(get_tests_path(), "outputs") From 2fe16de8e3338eaf834abc5b626fa19b16bb856b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 12 Jan 2022 14:30:53 +0000 Subject: [PATCH 128/214] Make lint --- TTS/bin/train_tts.py | 7 +----- TTS/tts/datasets/formatters.py | 1 - TTS/tts/models/glow_tts.py | 4 +++- TTS/tts/models/vits.py | 14 ++++++++---- TTS/tts/utils/text/tokenizer.py | 2 +- TTS/utils/synthesizer.py | 9 ++------ tests/tts_tests/test_glow_tts.py | 38 +++++++++++++++++++++++--------- 7 files changed, 45 insertions(+), 30 deletions(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index c36dc529..824f0128 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,14 +1,9 @@ import os -import torch - -from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config +from TTS.config import load_config, register_config from TTS.trainer import Trainer, TrainingArgs from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.audio import AudioProcessor def main(): diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 5168dd06..5a38039b 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -292,7 +292,6 @@ def brspeech(root_path, meta_file, ignored_speakers=None): def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None): """https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip""" file_ext = "flac" - test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for meta_file in meta_files: diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 
869adcad..23eb48da 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -123,7 +123,9 @@ class GlowTTS(BaseTTS): config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 ) if self.speaker_manager is not None: - assert config.d_vector_dim == self.speaker_manager.d_vector_dim, " [!] d-vector dimension mismatch b/w config and speaker manager." + assert ( + config.d_vector_dim == self.speaker_manager.d_vector_dim + ), " [!] d-vector dimension mismatch b/w config and speaker manager." # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: print(" > Init speaker_embedding layer.") diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index b5551268..2ecd1a07 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,5 +1,4 @@ import math -import random from dataclasses import dataclass, field, replace from itertools import chain from typing import Dict, List, Tuple, Union @@ -269,10 +268,20 @@ class Vits(BaseTTS): Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. Examples: + Init only model layers. + >>> from TTS.tts.configs.vits_config import VitsConfig >>> from TTS.tts.models.vits import Vits >>> config = VitsConfig() >>> model = Vits(config) + + Fully init a model ready for action. All the class attributes and class members + (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values. + + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> from TTS.tts.models.vits import Vits + >>> config = VitsConfig() + >>> model = Vits.init_from_config(config) """ # pylint: disable=dangerous-default-value @@ -908,13 +917,10 @@ class Vits(BaseTTS): aux_inputs["text"], self.config, "cuda" in str(next(self.parameters()).device), - self.ap, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], language_id=aux_inputs["language_id"], - language_name=aux_inputs["language_name"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, use_griffin_lim=True, do_trim_silence=False, ).values() diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index f84a51ee..80be368d 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -148,7 +148,7 @@ class TTSTokenizer: # init cleaners text_cleaner = None if isinstance(config.text_cleaner, (str, list)): - text_cleaner = getattr(config, "text_cleaner") + text_cleaner = getattr(cleaners, config.text_cleaner) # init characters if characters is None: diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index f6a1ae6a..a1a323e8 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -122,13 +122,9 @@ class Synthesizer(object): speaker_manager = self._init_speaker_encoder(speaker_manager) if language_manager is not None: - self.tts_model = setup_tts_model( - config=self.tts_config, - speaker_manager=speaker_manager, - language_manager=language_manager, - ) + self.tts_model = setup_tts_model(config=self.tts_config) else: - self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager) + self.tts_model = setup_tts_model(config=self.tts_config) self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) if use_cuda: self.tts_model.cuda() @@ -333,7 +329,6 @@ class Synthesizer(object): use_cuda=self.use_cuda, speaker_id=speaker_id, language_id=language_id, - language_name=language_name, style_wav=style_wav, use_griffin_lim=use_gl, 
d_vector=speaker_embedding, diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index e97b793a..e48977e9 100644 --- a/tests/tts_tests/test_glow_tts.py +++ b/tests/tts_tests/test_glow_tts.py @@ -1,8 +1,6 @@ import copy import os import unittest -from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.logging.tensorboard_logger import TensorboardLogger import torch from torch import optim @@ -11,7 +9,9 @@ from tests import get_tests_data_path, get_tests_input_path, get_tests_output_pa from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.models.glow_tts import GlowTTS +from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor +from TTS.utils.logging.tensorboard_logger import TensorboardLogger # pylint: disable=unused-variable @@ -31,7 +31,8 @@ def count_parameters(model): class TestGlowTTS(unittest.TestCase): - def _create_inputs(self): + @staticmethod + def _create_inputs(): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 @@ -40,7 +41,8 @@ class TestGlowTTS(unittest.TestCase): speaker_ids = torch.randint(0, 5, (8,)).long().to(device) return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids - def _check_parameter_changes(self, model, model_ref): + @staticmethod + def _check_parameter_changes(model, model_ref): count = 0 for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( @@ -166,7 +168,7 @@ class TestGlowTTS(unittest.TestCase): def _assert_inference_outputs(self, outputs, input_dummy, mel_spec): output_shape = outputs["model_outputs"].shape - self.assertEqual(outputs["model_outputs"].shape[::2] , mel_spec.shape[::2]) + self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2]) self.assertEqual(outputs["logdet"], None) self.assertEqual(outputs["y_mean"].shape, output_shape) self.assertEqual(outputs["y_log_scale"].shape, output_shape) @@ -185,7 +187,12 @@ class TestGlowTTS(unittest.TestCase): def test_inference_with_d_vector(self): input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() d_vector = torch.rand(8, 256).to(device) - config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json")) + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) model = GlowTTS.init_from_config(config, verbose=False).to(device) model.eval() outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) @@ -268,7 +275,9 @@ class TestGlowTTS(unittest.TestCase): model = GlowTTS.init_from_config(config, verbose=False).to(device) model.run_data_dep_init = False model.train() - logger = TensorboardLogger(log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name = "glow_tts_test_train_log") + logger = TensorboardLogger( + log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name="glow_tts_test_train_log" + ) criterion = model.get_criterion() outputs, _ = model.train_step(batch, criterion) model.train_log(batch, outputs, logger, None, 1) @@ -316,14 +325,23 @@ class TestGlowTTS(unittest.TestCase): self.assertTrue(model.num_speakers == 2) 
self.assertTrue(hasattr(model, "emb_g")) - config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json")) + config = GlowTTSConfig( + num_chars=32, + num_speakers=2, + use_speaker_embedding=True, + speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), + ) model = GlowTTS.init_from_config(config, verbose=False).to(device) self.assertTrue(model.num_speakers == 10) self.assertTrue(hasattr(model, "emb_g")) - config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json")) + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) model = GlowTTS.init_from_config(config, verbose=False).to(device) self.assertTrue(model.num_speakers == 1) self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(model.c_in_channels == config.d_vector_dim) - From 146fbfd7c90045f69eb027f54e9f3292eed9951c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 13 Jan 2022 17:39:06 +0000 Subject: [PATCH 129/214] Extend unittests --- TTS/tts/layers/vits/networks.py | 9 +- TTS/tts/models/vits.py | 46 ++++- tests/tts_tests/test_glow_tts.py | 89 ++++++---- tests/tts_tests/test_vits.py | 285 +++++++++++++++++++++++++++---- 4 files changed, 361 insertions(+), 68 deletions(-) diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index 7c225344..f97b584f 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -83,6 +83,7 @@ class TextEncoder(nn.Module): - x: :math:`[B, T]` - x_length: :math:`[B]` """ + assert x.shape[0] == x_lengths.shape[0] x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] # concat the lang emb in embedding chars @@ -90,7 +91,7 @@ class TextEncoder(nn.Module): x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t] x = self.encoder(x * x_mask, x_mask) stats = self.proj(x) * x_mask @@ -136,6 +137,9 @@ class ResidualCouplingBlock(nn.Module): def forward(self, x, x_mask, g=None, reverse=False): """ + Note: + Set `reverse` to True for inference. + Shapes: - x: :math:`[B, C, T]` - x_mask: :math:`[B, 1, T]` @@ -209,6 +213,9 @@ class ResidualCouplingBlocks(nn.Module): def forward(self, x, x_mask, g=None, reverse=False): """ + Note: + Set `reverse` to True for inference. 
+ Shapes: - x: :math:`[B, C, T]` - x_mask: :math:`[B, 1, T]` diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 2ecd1a07..4612c02b 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -563,6 +563,19 @@ class Vits(BaseTTS): - d_vectors: :math:`[B, C, 1]` - speaker_ids: :math:`[B]` - language_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + - m_q: :math:`[B, C, T_dec]` + - logs_q: :math:`[B, C, T_dec]` + - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` + - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` """ outputs = {} sid, g, lid = self._set_cond_input(aux_input) @@ -666,15 +679,33 @@ class Vits(BaseTTS): ) return outputs - def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}): + @staticmethod + def _set_x_lengths(x, aux_input): + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}): """ + Note: + To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. + Shapes: - x: :math:`[B, T_seq]` - - d_vectors: :math:`[B, C, 1]` + - x_lengths: :math:`[B]` + - d_vectors: :math:`[B, C]` - speaker_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` """ sid, g, lid = self._set_cond_input(aux_input) - x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_lengths = self._set_x_lengths(x, aux_input) # speaker embedding if self.args.use_speaker_embedding and sid is not None: @@ -704,8 +735,9 @@ class Vits(BaseTTS): w = torch.exp(logw) * x_mask * self.length_scale w_ceil = torch.ceil(w) y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] + + attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) @@ -1004,7 +1036,7 @@ class Vits(BaseTTS): assert not self.training @staticmethod - def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None): + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): """Initiate model from config Args: @@ -1014,7 +1046,7 @@ class Vits(BaseTTS): """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config) + ap = AudioProcessor.init_from_config(config, verbose=verbose) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) language_manager = LanguageManager.init_from_config(config) diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index e48977e9..305f86b8 100644 --- a/tests/tts_tests/test_glow_tts.py 
+++ b/tests/tts_tests/test_glow_tts.py @@ -23,6 +23,7 @@ c = GlowTTSConfig() ap = AudioProcessor(**c.audio) WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") +BATCH_SIZE = 3 def count_parameters(model): @@ -32,13 +33,13 @@ def count_parameters(model): class TestGlowTTS(unittest.TestCase): @staticmethod - def _create_inputs(): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8,)).long().to(device) + def _create_inputs(batch_size=8): + input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) - speaker_ids = torch.randint(0, 5, (8,)).long().to(device) + mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) + speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device) return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids @staticmethod @@ -104,8 +105,8 @@ class TestGlowTTS(unittest.TestCase): if getattr(f, "set_ddi", False): self.assertTrue(f.initialized) - def test_forward(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + def _test_forward(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) # create model config = GlowTTSConfig(num_chars=32) model = GlowTTS(config).to(device) @@ -114,16 +115,20 @@ class TestGlowTTS(unittest.TestCase): # inference encoder and decoder with MAS y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths) self.assertEqual(y["z"].shape, mel_spec.shape) - self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) self.assertEqual(y["y_mean"].shape, mel_spec.shape) self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) - def test_forward_with_d_vector(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() - d_vector = torch.rand(8, 256).to(device) + def test_forward(self): + self._test_forward(1) + self._test_forward(3) + + def _test_forward_with_d_vector(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + d_vector = torch.rand(batch_size, 256).to(device) # create model config = GlowTTSConfig( num_chars=32, @@ -137,16 +142,20 @@ class TestGlowTTS(unittest.TestCase): # inference encoder and decoder with MAS y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector}) self.assertEqual(y["z"].shape, mel_spec.shape) - self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) self.assertEqual(y["y_mean"].shape, mel_spec.shape) self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) - def test_forward_with_speaker_id(self): - 
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() - speaker_ids = torch.randint(0, 24, (8,)).long().to(device) + def test_forward_with_d_vector(self): + self._test_forward_with_d_vector(1) + self._test_forward_with_d_vector(3) + + def _test_forward_with_speaker_id(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + speaker_ids = torch.randint(0, 24, (batch_size,)).long().to(device) # create model config = GlowTTSConfig( num_chars=32, @@ -159,13 +168,17 @@ class TestGlowTTS(unittest.TestCase): # inference encoder and decoder with MAS y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids}) self.assertEqual(y["z"].shape, mel_spec.shape) - self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) self.assertEqual(y["y_mean"].shape, mel_spec.shape) self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + def test_forward_with_speaker_id(self): + self._test_forward_with_speaker_id(1) + self._test_forward_with_speaker_id(3) + def _assert_inference_outputs(self, outputs, input_dummy, mel_spec): output_shape = outputs["model_outputs"].shape self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2]) @@ -176,17 +189,21 @@ class TestGlowTTS(unittest.TestCase): self.assertEqual(outputs["durations_log"].shape, input_dummy.shape + (1,)) self.assertEqual(outputs["total_durations_log"].shape, input_dummy.shape + (1,)) - def test_inference(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + def _test_inference(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) config = GlowTTSConfig(num_chars=32) model = GlowTTS(config).to(device) model.eval() outputs = model.inference(input_dummy, {"x_lengths": input_lengths}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) - def test_inference_with_d_vector(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() - d_vector = torch.rand(8, 256).to(device) + def test_inference(self): + self._test_inference(1) + self._test_inference(3) + + def _test_inference_with_d_vector(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + d_vector = torch.rand(batch_size, 256).to(device) config = GlowTTSConfig( num_chars=32, use_d_vector_file=True, @@ -198,9 +215,13 @@ class TestGlowTTS(unittest.TestCase): outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) - def test_inference_with_speaker_ids(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() - speaker_ids = torch.randint(0, 24, (8,)).long().to(device) + def test_inference_with_d_vector(self): + self._test_inference_with_d_vector(1) + self._test_inference_with_d_vector(3) + + def _test_inference_with_speaker_ids(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + speaker_ids = torch.randint(0, 24, (batch_size,)).long().to(device) # create model 
config = GlowTTSConfig( num_chars=32, @@ -211,8 +232,12 @@ class TestGlowTTS(unittest.TestCase): outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) - def test_inference_with_MAS(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + def test_inference_with_speaker_ids(self): + self._test_inference_with_speaker_ids(1) + self._test_inference_with_speaker_ids(3) + + def _test_inference_with_MAS(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) # create model config = GlowTTSConfig(num_chars=32) model = GlowTTS(config).to(device) @@ -226,8 +251,13 @@ class TestGlowTTS(unittest.TestCase): y["model_outputs"].shape, y2["model_outputs"].shape ) + def test_inference_with_MAS(self): + self._test_inference_with_MAS(1) + self._test_inference_with_MAS(3) + def test_train_step(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + batch_size = BATCH_SIZE + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) criterion = GlowTTSLoss() # model to train config = GlowTTSConfig(num_chars=32) @@ -263,7 +293,8 @@ class TestGlowTTS(unittest.TestCase): self._check_parameter_changes(model, model_ref) def test_train_eval_log(self): - input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs() + batch_size = BATCH_SIZE + input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size) batch = {} batch["text_input"] = input_dummy batch["text_lengths"] = input_lengths diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 4274d947..53e7c09e 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -1,9 +1,11 @@ +import copy import os import unittest +from TTS.utils.logging.tensorboard_logger import TensorboardLogger import torch -from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path +from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path from TTS.config import load_config from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.tts.configs.vits_config import VitsConfig @@ -100,35 +102,35 @@ class TestVits(unittest.TestCase): self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len)) self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len)) - def _init_inputs(self, config): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8,)).long().to(device) + def _create_inputs(self, config, batch_size=2): + input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) input_lengths[-1] = 128 - spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device) - spec_lengths = torch.randint(20, 30, (8,)).long().to(device) + spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device) + spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) spec_lengths[-1] = spec.size(2) - waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device) + waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device) return input_dummy, input_lengths, spec, spec_lengths, waveform - def _check_forward_outputs(self, 
config, output_dict, encoder_config=None): + def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2): self.assertEqual( output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] ) - self.assertEqual(output_dict["alignments"].shape, (8, 128, 30)) + self.assertEqual(output_dict["alignments"].shape, (batch_size, 128, 30)) self.assertEqual(output_dict["alignments"].max(), 1) self.assertEqual(output_dict["alignments"].min(), 0) - self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["z"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["z_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["m_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["logs_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["m_q"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["logs_q"].shape, (batch_size, config.model_args.hidden_channels, 30)) self.assertEqual( output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] ) if encoder_config: - self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"])) - self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"])) + self.assertEqual(output_dict["gt_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"])) + self.assertEqual(output_dict["syn_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"])) else: self.assertEqual(output_dict["gt_spk_emb"], None) self.assertEqual(output_dict["syn_spk_emb"], None) @@ -137,7 +139,7 @@ class TestVits(unittest.TestCase): num_speakers = 0 config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) config.model_args.spec_segment_size = 10 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) + input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config) model = Vits(config).to(device) output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform) self._check_forward_outputs(config, output_dict) @@ -148,7 +150,7 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) config.model_args.spec_segment_size = 10 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) + input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config) speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) model = Vits(config).to(device) @@ -157,16 +159,36 @@ class TestVits(unittest.TestCase): ) self._check_forward_outputs(config, output_dict) + def test_d_vector_forward(self): + batch_size = 2 + args = VitsArgs( + spec_segment_size=10, + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + 
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + config = VitsConfig(model_args=args) + model = Vits.init_from_config(config, verbose=False).to(device) + model.train() + input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + d_vectors = torch.randn(batch_size, 256).to(device) + output_dict = model.forward( + input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors} + ) + self._check_forward_outputs(config, output_dict) + def test_multilingual_forward(self): num_speakers = 10 num_langs = 3 + batch_size = 2 args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10) config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) - speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) - lang_ids = torch.randint(0, num_langs, (8,)).long().to(device) + input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) model = Vits(config).to(device) output_dict = model.forward( @@ -182,6 +204,7 @@ class TestVits(unittest.TestCase): def test_secl_forward(self): num_speakers = 10 num_langs = 3 + batch_size = 2 speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG) speaker_encoder_config.model_params["use_torch_spec"] = True @@ -198,9 +221,9 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) config.audio.sample_rate = 16000 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) - speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) - lang_ids = torch.randint(0, num_langs, (8,)).long().to(device) + input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) model = Vits(config, speaker_manager=speaker_manager).to(device) output_dict = model.forward( @@ -213,28 +236,228 @@ class TestVits(unittest.TestCase): ) self._check_forward_outputs(config, output_dict, speaker_encoder_config) + def _check_inference_outputs(self, config, outputs, input_dummy, batch_size=1): + feat_len = outputs["z"].shape[2] + self.assertEqual(outputs["model_outputs"].shape[:2], (batch_size, 1)) # we don't know the channel dimension + self.assertEqual(outputs["alignments"].shape, (batch_size, input_dummy.shape[1], feat_len)) + self.assertEqual(outputs["z"].shape, (batch_size, config.model_args.hidden_channels, feat_len)) + self.assertEqual(outputs["z_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len)) + self.assertEqual(outputs["m_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len)) + self.assertEqual(outputs["logs_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len)) + def test_inference(self): num_speakers = 0 config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) - input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) model = Vits(config).to(device) - _ = model.inference(input_dummy) + + batch_size = 1 + input_dummy, *_ = 
self._create_inputs(config, batch_size=batch_size) + outputs = model.inference(input_dummy) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) + + batch_size = 2 + input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) def test_multispeaker_inference(self): num_speakers = 10 config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) - input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) - speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device) model = Vits(config).to(device) - _ = model.inference(input_dummy, {"speaker_ids": speaker_ids}) + + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) + + batch_size = 2 + input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) def test_multilingual_inference(self): num_speakers = 10 num_langs = 3 args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10) config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) + model = Vits(config).to(device) + input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device) lang_ids = torch.randint(0, num_langs, (1,)).long().to(device) - model = Vits(config).to(device) _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids}) + + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) + outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) + + batch_size = 2 + input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) + outputs = model.inference( + input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids, "language_ids": lang_ids} + ) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) + + def test_d_vector_inference(self): + args = VitsArgs( + spec_segment_size=10, + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + config = VitsConfig(model_args=args) + model = Vits.init_from_config(config, verbose=False).to(device) + model.eval() + # batch size = 1 + input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) + d_vectors = torch.randn(1, 256).to(device) + outputs = model.inference(input_dummy, 
aux_input={"d_vectors": d_vectors}) + self._check_inference_outputs(config, outputs, input_dummy) + # batch size = 2 + input_dummy, input_lengths, *_ = self._create_inputs(config) + d_vectors = torch.randn(2, 256).to(device) + outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=2) + + @staticmethod + def _check_parameter_changes(model, model_ref): + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) + count += 1 + + def _create_batch(self, config, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(config, batch_size) + batch = {} + batch["text_input"] = input_dummy + batch["text_lengths"] = input_lengths + batch["mel_lengths"] = mel_lengths + batch["linear_input"] = mel_spec.transpose(1, 2) + batch["waveform"] = torch.rand(batch_size, config.audio["sample_rate"] * 10, 1).to(device) + batch["d_vectors"] = None + batch["speaker_ids"] = None + batch["language_ids"] = None + return batch + + def test_train_step(self): + # setup the model + config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10)) + model = Vits(config).to(device) + # create a batch + batch = self._create_batch(config, 1) + # model to train + criterions = model.get_criterion() + criterions = [criterions[0].to(device), criterions[1].to(device)] + # reference model to compare model weights + model_ref = Vits(config).to(device) + model.train() + # pass the state to ref model + model_ref.load_state_dict(copy.deepcopy(model.state_dict())) + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + optimizers = model.get_optimizer() + for _ in range(5): + _, loss_dict = model.train_step(batch, criterions, 0) + loss = loss_dict["loss"] + loss.backward() + optimizers[0].step() + + _, loss_dict = model.train_step(batch, criterions, 1) + loss = loss_dict["loss"] + loss.backward() + optimizers[1].step() + # check parameter changes + self._check_parameter_changes(model, model_ref) + + def test_train_eval_log(self): + batch_size = 2 + config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10)) + model = Vits.init_from_config(config, verbose=False).to(device) + model.run_data_dep_init = False + model.train() + batch = self._create_batch(config, batch_size) + logger = TensorboardLogger( + log_dir=os.path.join(get_tests_output_path(), "dummy_vits_logs"), model_name="vits_test_train_log" + ) + criterion = model.get_criterion() + criterion = [criterion[0].to(device), criterion[1].to(device)] + outputs = [None] * 2 + outputs[0], _ = model.train_step(batch, criterion, 0) + outputs[1], _ = model.train_step(batch, criterion, 1) + model.train_log(batch, outputs, logger, None, 1) + + model.eval_log(batch, outputs, logger, None, 1) + logger.finish() + + def test_test_run(self): + config = VitsConfig(model_args=VitsArgs(num_chars=32)) + model = Vits.init_from_config(config, verbose=False).to(device) + model.run_data_dep_init = False + model.eval() + test_figures, test_audios = model.test_run(None) + self.assertTrue(test_figures is not None) + self.assertTrue(test_audios is not None) + + def test_load_checkpoint(self): + chkp_path = os.path.join(get_tests_output_path(), 
"dummy_glow_tts_checkpoint.pth") + config = VitsConfig(VitsArgs(num_chars=32)) + model = Vits.init_from_config(config, verbose=False).to(device) + chkp = {} + chkp["model"] = model.state_dict() + torch.save(chkp, chkp_path) + model.load_checkpoint(config, chkp_path) + self.assertTrue(model.training) + model.load_checkpoint(config, chkp_path, eval=True) + self.assertFalse(model.training) + + def test_get_criterion(self): + config = VitsConfig(VitsArgs(num_chars=32)) + model = Vits.init_from_config(config, verbose=False).to(device) + criterion = model.get_criterion() + self.assertTrue(criterion is not None) + + def test_init_from_config(self): + config = VitsConfig(model_args=VitsArgs(num_chars=32)) + model = Vits.init_from_config(config, verbose=False).to(device) + + config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2)) + model = Vits.init_from_config(config, verbose=False).to(device) + self.assertTrue(not hasattr(model, "emb_g")) + + config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True)) + model = Vits.init_from_config(config, verbose=False).to(device) + self.assertEqual(model.num_speakers, 2) + self.assertTrue(hasattr(model, "emb_g")) + + config = VitsConfig(model_args=VitsArgs( + num_chars=32, + num_speakers=2, + use_speaker_embedding=True, + speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), + )) + model = Vits.init_from_config(config, verbose=False).to(device) + self.assertEqual(model.num_speakers, 10) + self.assertTrue(hasattr(model, "emb_g")) + + config = VitsConfig(model_args=VitsArgs( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + )) + model = Vits.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 1) + self.assertTrue(not hasattr(model, "emb_g")) + self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim) From 21940952bf895af2e0bd88f907118d9471266601 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 13 Jan 2022 17:43:05 +0000 Subject: [PATCH 130/214] Make lint --- tests/tts_tests/test_vits.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 53e7c09e..eaa325b0 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -1,7 +1,6 @@ import copy import os import unittest -from TTS.utils.logging.tensorboard_logger import TensorboardLogger import torch @@ -11,6 +10,7 @@ from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.models.vits import Vits, VitsArgs from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.logging.tensorboard_logger import TensorboardLogger LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") @@ -337,7 +337,7 @@ class TestVits(unittest.TestCase): count += 1 def _create_batch(self, config, batch_size): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(config, batch_size) + input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size) batch = {} batch["text_input"] = input_dummy batch["text_lengths"] = input_lengths @@ -441,22 +441,26 @@ class TestVits(unittest.TestCase): 
self.assertEqual(model.num_speakers, 2) self.assertTrue(hasattr(model, "emb_g")) - config = VitsConfig(model_args=VitsArgs( - num_chars=32, - num_speakers=2, - use_speaker_embedding=True, - speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), - )) + config = VitsConfig( + model_args=VitsArgs( + num_chars=32, + num_speakers=2, + use_speaker_embedding=True, + speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), + ) + ) model = Vits.init_from_config(config, verbose=False).to(device) self.assertEqual(model.num_speakers, 10) self.assertTrue(hasattr(model, "emb_g")) - config = VitsConfig(model_args=VitsArgs( - num_chars=32, - use_d_vector_file=True, - d_vector_dim=256, - d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), - )) + config = VitsConfig( + model_args=VitsArgs( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + ) model = Vits.init_from_config(config, verbose=False).to(device) self.assertTrue(model.num_speakers == 1) self.assertTrue(not hasattr(model, "emb_g")) From bc2243bac4ef81b120f417fa979013ddd6ac4f27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 14 Jan 2022 12:10:39 +0000 Subject: [PATCH 131/214] Fix tests --- TTS/bin/find_unique_phonemes.py | 8 +++++--- tests/aux_tests/test_find_unique_phonemes.py | 2 -- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index ad567434..c5501552 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -7,14 +7,16 @@ from tqdm.contrib.concurrent import process_map from TTS.config import load_config from TTS.tts.datasets import load_tts_samples -from TTS.tts.utils.text import text2phone +from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + + +phonemizer = Gruut(language="en-us") def compute_phonemes(item): try: text = item[0] - language = item[-1] - ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|") + ph = phonemizer.phonemize(text).split("|") except: return [] return list(set(ph)) diff --git a/tests/aux_tests/test_find_unique_phonemes.py b/tests/aux_tests/test_find_unique_phonemes.py index fa0abe4b..fa740ba3 100644 --- a/tests/aux_tests/test_find_unique_phonemes.py +++ b/tests/aux_tests/test_find_unique_phonemes.py @@ -39,7 +39,6 @@ class TestFindUniquePhonemes(unittest.TestCase): num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, @@ -64,7 +63,6 @@ class TestFindUniquePhonemes(unittest.TestCase): num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, From 47fbddc8d468bbba8b086fe759d37e2e454019a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 14 Jan 2022 12:10:54 +0000 Subject: [PATCH 132/214] Fix docstring --- TTS/tts/datasets/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index d80e92c9..dde85808 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -13,6 +13,7 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into 
train and eval. Consider speaker distribution in multi-speaker training.
 
     Args:
-        items (List[List]):
-            A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
+        items (List[List]):
+            A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
 
         eval_split_size (float):
             If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
             If > 1, represents the absolute number of evaluation samples.
             Defaults to 0.01 (1%).
     """
     speakers = [item["speaker_name"] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
@@ -68,7 +72,7 @@ def load_tts_samples(
         formatter (Callable, optional):
             The preprocessing function to be applied to create the list of samples. It must take the root_path and the
-            meta_file name and return a list of samples in the format of `[[audio_path, text, speaker_id], ...]]`. See
+            meta_file name and return a list of samples in the format of `[[text, audio_path, speaker_id], ...]]`. See
             the available formatters in `TTS.tts.dataset.formatter` as example. Defaults to None.

From c4c471d61d30ba29d0da164329df8d8d21065b18 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 21 Jan 2022 15:27:41 +0000
Subject: [PATCH 133/214] Allow padding for shorter segments

---
 TTS/tts/utils/helpers.py        | 37 ++++++++++++++++++++++++++-------
 tests/tts_tests/test_helpers.py | 30 +++++++++++++++++++++++++-
 2 files changed, 58 insertions(+), 9 deletions(-)

diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py
index b0a010b0..32513377 100644
--- a/TTS/tts/utils/helpers.py
+++ b/TTS/tts/utils/helpers.py
@@ -57,40 +57,61 @@ def sequence_mask(sequence_length, max_len=None):
     return mask
 
 
-def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
+def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
     """Segment each sample in a batch based on the provided segment indices
 
     Args:
         x (torch.tensor): Input tensor.
         segment_indices (torch.tensor): Segment indices.
         segment_size (int): Expected output segment size.
+        pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
     """
+    # pad the input tensor if it is shorter than the segment size
+    if pad_short and x.shape[-1] < segment_size:
+        x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
+
     segments = torch.zeros_like(x[:, :, :segment_size])
+
     for i in range(x.size(0)):
         index_start = segment_indices[i]
         index_end = index_start + segment_size
-        segments[i] = x[i, :, index_start:index_end]
+        x_i = x[i]
+        if pad_short and index_end > x.size(2):
+            # pad the sample if it is shorter than the segment size
+            x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
+        segments[i] = x_i[:, index_start:index_end]
     return segments
 
 
-def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
+def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False):
     """Create random segments based on the input lengths.
 
     Args:
         x (torch.tensor): Input tensor.
         x_lengths (torch.tensor): Input lengths.
         segment_size (int): Expected output segment size.
+        let_short_samples (bool): Allow shorter samples than the segment size.
+        pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
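A minimal sketch of the new padding path, assuming the `[B, C, T]` layout documented below; the sizes are purely illustrative:

import torch

from TTS.tts.utils.helpers import rand_segments, segment

x = torch.randn(2, 80, 3)  # batch of 2, 80 channels, only 3 frames
x_lens = torch.tensor([3, 3])

# T < segment_size would normally assert; with pad_short=True the input is
# zero-padded up to the segment size and short samples are allowed through.
segs, idxs = rand_segments(x, x_lens, segment_size=4, let_short_samples=True, pad_short=True)
assert segs.shape == (2, 80, 4)

# segment() pads the same way when a slice would run past the last frame.
segs = segment(x, torch.tensor([0, 0]), segment_size=4, pad_short=True)
assert segs.shape == (2, 80, 4)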
 
     Shapes:
         - x: :math:`[B, C, T]`
         - x_lengths: :math:`[B]`
     """
+    _x_lenghts = x_lengths.clone()
     B, _, T = x.size()
-    if x_lengths is None:
-        x_lengths = T
-    max_idxs = x_lengths - segment_size + 1
-    assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
-    segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
+    if pad_short:
+        if T < segment_size:
+            x = torch.nn.functional.pad(x, (0, segment_size - T))
+            T = segment_size
+    if _x_lenghts is None:
+        _x_lenghts = T
+    len_diff = _x_lenghts - segment_size + 1
+    if let_short_samples:
+        _x_lenghts[len_diff < 0] = segment_size
+        len_diff = _x_lenghts - segment_size + 1
+    else:
+        assert all(len_diff > 0), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
+    segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
     ret = segment(x, segment_indices, segment_size)
     return ret, segment_indices

diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py
index 6a2f260d..708ecbf5 100644
--- a/tests/tts_tests/test_helpers.py
+++ b/tests/tts_tests/test_helpers.py
@@ -1,6 +1,6 @@
 import torch as T
 
-from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask
+from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments
 
 
 def average_over_durations_test():  # pylint: disable=no-self-use
@@ -39,6 +39,34 @@ def segment_test():
     for idx, start_indx in enumerate(segment_ids):
         assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum()
 
+    try:
+        segments = segment(x, segment_ids, segment_size=10)
+        raise Exception("Should have failed")
+    except:
+        pass
+
+    segments = segment(x, segment_ids, segment_size=10, pad_short=True)
+    for idx, start_indx in enumerate(segment_ids):
+        assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()
+
+
+def rand_segments_test():
+    x = T.rand(2, 3, 4)
+    x_lens = T.randint(3, 4, (2,))
+    segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
+    assert segments.shape == (2, 3, 3)
+    assert all(seg_idxs >= 0), seg_idxs
+    try:
+        segments, _ = rand_segments(x, x_lens, segment_size=5)
+        raise Exception("Should have failed")
+    except:
+        pass
+    x_lens_back = x_lens.clone()
+    segments, seg_idxs= rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True)
+    assert segments.shape == (2, 3, 5)
+    assert all(seg_idxs >= 0), seg_idxs
+    assert all(x_lens_back == x_lens)
+
 
 def generate_path_test():
     durations = T.randint(1, 4, (10, 21))

From ef63c995248fb854d1efae73acfbdcf75666c263 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 21 Jan 2022 15:29:06 +0000
Subject: [PATCH 134/214] Implement `start_by_longest` option for TTSDataset

---
 TTS/tts/configs/shared_configs.py |  5 +++++
 TTS/tts/configs/vits_config.py    | 13 +------------
 TTS/tts/datasets/dataset.py       | 10 ++++++++++
 TTS/tts/models/base_tts.py        |  1 +
 tests/data_tests/test_loader.py   | 18 ++++++++++++++++++
 5 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py
index ad3bbe70..5c271f07 100644
--- a/TTS/tts/configs/shared_configs.py
+++ b/TTS/tts/configs/shared_configs.py
@@ -172,6 +172,10 @@ class BaseTTSConfig(BaseTrainingConfig):
     use_noise_augment (bool):
         Augment the input audio with random noise.
 
+    start_by_longest (bool):
+        If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
+ Defaults to False. + add_blank (bool): Add blank characters between each other two characters. It improves performance for some models at expense of slower run-time due to the longer input sequence. @@ -231,6 +235,7 @@ class BaseTTSConfig(BaseTrainingConfig): compute_linear_spec: bool = False precompute_num_workers: int = 0 use_noise_augment: bool = False + start_by_longest: bool = False # dataset datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # optimizer diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index 36c948af..d306552d 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec (bool): If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. - sort_by_audio_len (bool): - If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`. - - min_seq_len (int): - Minimum sequnce length to be considered for training. Defaults to `0`. - - max_seq_len (int): - Maximum sequnce length to be considered for training. Defaults to `500000`. - r (int): Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. @@ -123,6 +114,7 @@ class VitsConfig(BaseTTSConfig): feat_loss_alpha: float = 1.0 mel_loss_alpha: float = 45.0 dur_loss_alpha: float = 1.0 + aligner_loss_alpha = 1.0 speaker_encoder_loss_alpha: float = 1.0 # data loader params @@ -130,9 +122,6 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec: bool = True # overrides - sort_by_audio_len: bool = True - min_seq_len: int = 0 - max_seq_len: int = 500000 r: int = 1 # DO NOT CHANGE add_blank: bool = True diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index a98afc95..a1bb23c3 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -56,6 +56,7 @@ class TTSDataset(Dataset): d_vector_mapping: Dict = None, language_id_mapping: Dict = None, use_noise_augment: bool = False, + start_by_longest: bool = False, verbose: bool = False, ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. @@ -109,6 +110,8 @@ class TTSDataset(Dataset): use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. + start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + verbose (bool): Print diagnostic information. Defaults to false. 
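The effect on batching can be sketched in a few lines; the lengths below are made up, and only the first and last positions of the sorted index list are touched:

import numpy as np

audio_lengths = [120, 300, 80, 950]
sorted_idxs = np.argsort(audio_lengths)  # ascending: [2, 0, 1, 3]

# move the longest sample to the front so the very first batch is the
# worst case for memory; the rest of the epoch keeps the sorted order
longest_idx = sorted_idxs[-1]
sorted_idxs[-1] = sorted_idxs[0]
sorted_idxs[0] = longest_idx  # [3, 0, 1, 2]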
""" super().__init__() @@ -130,6 +133,7 @@ class TTSDataset(Dataset): self.d_vector_mapping = d_vector_mapping self.language_id_mapping = language_id_mapping self.use_noise_augment = use_noise_augment + self.start_by_longest = start_by_longest self.verbose = verbose self.rescue_item_idx = 1 @@ -315,6 +319,12 @@ class TTSDataset(Dataset): samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx) sorted_idxs = self.sort_by_length(audio_lengths) + + if self.start_by_longest: + longest_idxs = sorted_idxs[-1] + sorted_idxs[-1] = sorted_idxs[0] + sorted_idxs[0] = longest_idxs + samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs) if len(samples) == 0: diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 9a6a56df..7cdfa915 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -290,6 +290,7 @@ class BaseTTS(BaseModel): speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, language_id_mapping=language_id_mapping, ) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 3ecd42e1..f96154bc 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -63,6 +63,7 @@ class TestTTSDataset(unittest.TestCase): max_text_len=c.max_text_len, min_audio_len=c.min_audio_len, max_audio_len=c.max_audio_len, + start_by_longest=start_by_longest ) dataloader = DataLoader( dataset, @@ -142,6 +143,23 @@ class TestTTSDataset(unittest.TestCase): self.assertGreaterEqual(avg_length, last_length) self.assertTrue(is_items_reordered) + def test_start_by_longest(self): + """Test start_by_longest option. + + Ther first item of the fist batch must be longer than all the other items. + """ + if ok_ljspeech: + dataloader, _ = self._create_dataloader(2, c.r, 0, True) + dataloader.dataset.preprocess_samples() + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + mel_lengths = data["mel_lengths"] + if i == 0: + max_len = mel_lengths[0] + print(mel_lengths) + self.assertTrue(all(max_len >= mel_lengths)) + def test_padding_and_spectrograms(self): def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths): self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding From 2829027d8b4b72233d1948525b7fcdacd1fa23e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 21 Jan 2022 15:33:15 +0000 Subject: [PATCH 135/214] Refactor VITS model --- TTS/tts/models/vits.py | 109 +++++++++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 41 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 4612c02b..222bbca5 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -38,7 +38,7 @@ class VitsArgs(Coqpit): Number of characters in the vocabulary. Defaults to 100. out_channels (int): - Number of output channels. Defaults to 513. + Number of output channels of the decoder. Defaults to 513. spec_segment_size (int): Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. @@ -363,6 +363,8 @@ class Vits(BaseTTS): language_emb_dim=self.embedded_language_dim, ) + upsample_rate = math.prod(self.args.upsample_rates_decoder) + assert upsample_rate == self.config.audio.hop_length, f" [!] 
Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}" self.waveform_decoder = HifiganGenerator( self.args.hidden_channels, 1, @@ -531,6 +533,54 @@ class Vits(BaseTTS): "language_name": language_name, } + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): + # find the alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + with torch.no_grad(): + o_scale = torch.exp(-2 * logs_p) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)]) + logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) + logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + + # duration predictor + attn_durations = attn.sum(3) + if self.args.use_sdp: + loss_duration = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + attn_durations, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = loss_duration / torch.sum(x_mask) + else: + attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + log_durations = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) + outputs["loss_duration"] = loss_duration + return outputs, attn + def forward( self, x: torch.tensor, @@ -596,54 +646,27 @@ class Vits(BaseTTS): # flow layers z_p = self.flow(z, y_mask, g=g) - # find the alignment path - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - with torch.no_grad(): - o_scale = torch.exp(-2 * logs_p) - logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) - logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) - logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp = logp2 + logp3 + logp1 + logp4 - attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() - # duration predictor - attn_durations = attn.sum(3) - g_dp = None - if self.args.condition_dp_on_speaker: - g_dp = g.detach() if self.args.detach_dp_input and g is not None else g - if self.args.use_sdp: - loss_duration = self.duration_predictor( - x.detach() if self.args.detach_dp_input else x, - x_mask, - attn_durations, - g=g_dp, - lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, - ) - loss_duration = loss_duration / 
torch.sum(x_mask) - else: - attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask - log_durations = self.duration_predictor( - x.detach() if self.args.detach_dp_input else x, - x_mask, - g=g_dp, - lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, - ) - loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) - outputs["loss_duration"] = loss_duration + if self.args.use_mas: + outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + elif self.args.use_aligner_network: + outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb) + outputs["x_lens"] = x_lengths + outputs["y_lens"] = y_lengths # expand prior m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) # select a random feature segment for the waveform decoder - z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size) + z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) o = self.waveform_decoder(z_slice, g=g) wav_seg = segment( waveform, slice_ids * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length, + pad_short=True ) if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: @@ -665,11 +688,11 @@ class Vits(BaseTTS): outputs.update( { "model_outputs": o, - "alignments": attn.squeeze(1), - "z": z, - "z_p": z_p, + "alignments" : attn.squeeze(1), "m_p": m_p, "logs_p": logs_p, + "z": z, + "z_p": z_p, "m_q": m_q, "logs_q": logs_q, "waveform_seg": wav_seg, @@ -919,14 +942,18 @@ class Vits(BaseTTS): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - self._log(self.ap, batch, outputs, "train") + figures, audios = self._log(self.ap, batch, outputs, "train") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - return self._log(self.ap, batch, outputs, "eval") + figures, audios = self._log(self.ap, batch, outputs, "eval") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def test_run(self, assets) -> Tuple[Dict, Dict]: From 13482dde1f031a37f6882842f376dbd77afc2082 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 09:22:35 +0000 Subject: [PATCH 136/214] Update GAN model --- TTS/vocoder/models/gan.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index f78d69b8..7e03e94f 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results class GAN(BaseVocoder): - def __init__(self, config: Coqpit): + def __init__(self, config: Coqpit, ap: AudioProcessor=None): """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. It also helps mixing and matching different generator and disciminator networks easily. @@ -28,6 +28,7 @@ class GAN(BaseVocoder): Args: config (Coqpit): Model configuration. + ap (AudioProcessor): 🐸TTS AudioProcessor instance. 
Defaults to None. Examples: Initializing the GAN model with HifiGAN generator and discriminator. @@ -41,6 +42,7 @@ class GAN(BaseVocoder): self.model_d = setup_discriminator(config) self.train_disc = False # if False, train only the generator. self.y_hat_g = None # the last generator prediction to be passed onto the discriminator + self.ap = ap def forward(self, x: torch.Tensor) -> torch.Tensor: """Run the generator's forward pass. @@ -201,10 +203,9 @@ class GAN(BaseVocoder): self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument ) -> Tuple[Dict, np.ndarray]: """Call `_log()` for training.""" - ap = assets["audio_processor"] - figures, audios = self._log("eval", ap, batch, outputs) + figures, audios = self._log("eval", self.ap, batch, outputs) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: @@ -215,10 +216,9 @@ class GAN(BaseVocoder): self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument ) -> Tuple[Dict, np.ndarray]: """Call `_log()` for evaluation.""" - ap = assets["audio_processor"] - figures, audios = self._log("eval", ap, batch, outputs) + figures, audios = self._log("eval", self.ap, batch, outputs) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( self, @@ -330,12 +330,11 @@ class GAN(BaseVocoder): Returns: DataLoader: Torch dataloader. """ - ap = assets["audio_processor"] dataset = GANDataset( - ap=ap, + ap=self.ap, items=data_items, seq_len=config.seq_len, - hop_len=ap.hop_length, + hop_len=self.ap.hop_length, pad_short=config.pad_short, conv_pad=config.conv_pad, return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, @@ -363,5 +362,6 @@ class GAN(BaseVocoder): return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)] @staticmethod - def init_from_config(config: Coqpit) -> "GAN": - return GAN(config) + def init_from_config(config: Coqpit, verbose=True) -> "GAN": + ap = AudioProcessor.init_from_config(config, verbose=verbose) + return GAN(config, ap=ap) From 7058fcc3ff741ac1ca888ea3ba2aa54d6e2c6279 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 09:23:07 +0000 Subject: [PATCH 137/214] Take file extension as an argument --- TTS/vocoder/datasets/preprocess.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py index d8cc350a..0f69b812 100644 --- a/TTS/vocoder/datasets/preprocess.py +++ b/TTS/vocoder/datasets/preprocess.py @@ -33,8 +33,8 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): np.save(quant_path, quant) -def find_wav_files(data_path): - wav_paths = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) +def find_wav_files(data_path, file_ext="wav"): + wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True) return wav_paths @@ -43,8 +43,9 @@ def find_feat_files(data_path): return feat_paths -def load_wav_data(data_path, eval_split_size): - wav_paths = find_wav_files(data_path) +def load_wav_data(data_path, eval_split_size, file_ext="wav"): + wav_paths = find_wav_files(data_path, file_ext=file_ext) 
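Usage sketch for the new argument; the dataset path is a placeholder, and any extension that globs as `*.{file_ext}` should work:

from TTS.vocoder.datasets.preprocess import load_wav_data

# returns (eval_samples, train_samples); "flac" here is just an example
eval_samples, train_samples = load_wav_data(
    "/data/my_vocoder_dataset", eval_split_size=10, file_ext="flac"
)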
+    assert len(wav_paths) > 0, f" [!] {data_path} is empty."
     np.random.seed(0)
     np.random.shuffle(wav_paths)
     return wav_paths[:eval_split_size], wav_paths[eval_split_size:]

From 1445a46e9e38e8720e3a981f5115460510a634d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 25 Jan 2022 09:25:32 +0000
Subject: [PATCH 138/214] Update synthesizer to use init_from_config

---
 TTS/utils/synthesizer.py | 52 ----------------------------------------
 1 file changed, 52 deletions(-)

diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index a1a323e8..ddc2a6a5 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -110,21 +110,12 @@ class Synthesizer(object):
             use_cuda (bool): enable/disable CUDA use.
         """
         # pylint: disable=global-statement
-
         self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes
         self.tts_model = setup_tts_model(config=self.tts_config)
 
-        speaker_manager = self._init_speaker_manager()
-        language_manager = self._init_language_manager()
         if not self.encoder_checkpoint:
             self._set_speaker_encoder_paths_from_tts_config()
-        speaker_manager = self._init_speaker_encoder(speaker_manager)
-
-        if language_manager is not None:
-            self.tts_model = setup_tts_model(config=self.tts_config)
-        else:
-            self.tts_model = setup_tts_model(config=self.tts_config)
         self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()
@@ -157,49 +148,6 @@
         use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False)
         return use_d_vector_file
 
-    def _init_speaker_manager(self):
-        """Initialize the SpeakerManager"""
-        # setup if multi-speaker settings are in the global model config
-        speaker_manager = None
-        speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None)
-        if self._is_use_speaker_embedding():
-            if self.tts_speakers_file:
-                speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
-            elif speakers_file:
-                speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file)
-
-        if self._is_use_d_vector_file():
-            d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None)
-            if self.tts_speakers_file:
-                speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
-            elif d_vector_file:
-                speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file)
-        return speaker_manager
-
-    def _init_speaker_encoder(self, speaker_manager):
-        """Initialize the SpeakerEncoder"""
-        if self.encoder_checkpoint:
-            if speaker_manager is None:
-                speaker_manager = SpeakerManager(
-                    encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config
-                )
-            else:
-                speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
-        return speaker_manager
-
-    def _init_language_manager(self):
-        """Initialize the LanguageManager"""
-        # setup if multi-lingual settings are in the global model config
-        language_manager = None
-        if check_config_and_model_args(self.tts_config, "use_language_embedding", True):
-            if self.tts_languages_file:
-                language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
-            elif self.tts_config.get("language_ids_file", None):
-                language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file)
-            else:
-                language_manager = LanguageManager(config=self.tts_config)
-        return language_manager
-
     def _load_vocoder(self, model_file: str, model_config: str,
use_cuda: bool) -> None:
        """Load the vocoder model.

From cd5d1497cff1f1616ade65a0962bd72c55f085e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 25 Jan 2022 09:26:23 +0000
Subject: [PATCH 139/214] Add pitch_fmin pitch_fmax args to the audio

---
 TTS/utils/audio.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index 4d20f468..d0777c11 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -239,6 +239,12 @@ class AudioProcessor(object):
         mel_fmax (int, optional):
             maximum filter frequency for computing melspectrograms. Defaults to None.
 
+        pitch_fmin (int, optional):
+            minimum filter frequency for computing pitch. Defaults to None.
+
+        pitch_fmax (int, optional):
+            maximum filter frequency for computing pitch. Defaults to None.
+
         spec_gain (int, optional):
             gain applied when converting amplitude to DB. Defaults to 20.
@@ -300,6 +306,8 @@
         max_norm=None,
         mel_fmin=None,
         mel_fmax=None,
+        pitch_fmax=None,
+        pitch_fmin=None,
         spec_gain=20,
         stft_pad_mode="reflect",
         clip_norm=True,
@@ -333,6 +341,8 @@
         self.symmetric_norm = symmetric_norm
         self.mel_fmin = mel_fmin or 0
         self.mel_fmax = mel_fmax
+        self.pitch_fmin = pitch_fmin
+        self.pitch_fmax = pitch_fmax
         self.spec_gain = float(spec_gain)
         self.stft_pad_mode = stft_pad_mode
         self.max_norm = 1.0 if max_norm is None else float(max_norm)
@@ -726,12 +736,12 @@
             >>> WAV_FILE = filename = librosa.util.example_audio_file()
             >>> from TTS.config import BaseAudioConfig
             >>> from TTS.utils.audio import AudioProcessor
-            >>> conf = BaseAudioConfig(mel_fmax=8000)
+            >>> conf = BaseAudioConfig(pitch_fmax=8000)
             >>> ap = AudioProcessor(**conf)
             >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
             >>> pitch = ap.compute_f0(wav)
         """
-        assert self.mel_fmax is not None, " [!] Set `mel_fmax` before caling `compute_f0`."
+        assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`."
         # align F0 length to the spectrogram length
         if len(x) % self.hop_length == 0:
             x = np.pad(x, (0, self.hop_length // 2), mode="reflect")
 
         f0, t = pw.dio(
             x.astype(np.double),
             fs=self.sample_rate,
-            f0_ceil=self.mel_fmax,
+            f0_ceil=self.pitch_fmax,
             frame_period=1000 * self.hop_length / self.sample_rate,
         )
         f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)

From 5169d4eb32407ca0278046aaffc56ca6f9e9ef32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 25 Jan 2022 09:26:47 +0000
Subject: [PATCH 140/214] Plot pitch over input characters

---
 TTS/tts/utils/visual.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py
index de6d95c5..4fd1f19c 100644
--- a/TTS/tts/utils/visual.py
+++ b/TTS/tts/utils/visual.py
@@ -87,6 +87,39 @@ def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False)
     return fig
 
 
+def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
+    """Plot pitch curves on top of the input characters.
+
+    Args:
+        pitch (np.array): Pitch values.
+        chars (str): Characters to place on the x-axis.
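A usage sketch, assuming the pitch values have already been averaged per input character (as the ForwardTTS logging below does); the numbers are invented:

import numpy as np

from TTS.tts.utils.visual import plot_avg_pitch

chars = "Hello"
avg_pitch = np.array([110.0, 150.0, 95.0, 120.0, 140.0])  # one F0 value per character
fig = plot_avg_pitch(avg_pitch, chars, fig_size=(12, 5), output_fig=True)
fig.savefig("avg_pitch.png")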
+ + Shapes: + pitch: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = [c for c in chars] + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + def visualize( alignment, postnet_output, From bb3746279454b504f7730ff4f8d63bf0d8986c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 09:27:13 +0000 Subject: [PATCH 141/214] Update language manager --- TTS/tts/utils/languages.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 78b535a0..54ba40b2 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -1,6 +1,7 @@ import json import os from typing import Dict, List +from TTS.config import check_config_and_model_args import fsspec import numpy as np @@ -105,7 +106,12 @@ class LanguageManager: Args: config (Coqpit): Coqpit config. """ - return LanguageManager(config=config) + language_manager = None + if check_config_and_model_args(config, "use_language_embedding", True): + if config.get("language_ids_file", None): + language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + language_manager = LanguageManager(config=config) + return language_manager def _set_file_path(path): From 34c4be5e49a8c9bdc15e59494f1be5410b5415e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 09:28:33 +0000 Subject: [PATCH 142/214] Update forwardtts --- TTS/tts/layers/losses.py | 6 +++++- TTS/tts/models/forward_tts.py | 29 ++++++++++++----------------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index d770a536..f4a472ad 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -740,6 +740,7 @@ class ForwardTTSLoss(nn.Module): alignment_logprob=None, alignment_hard=None, alignment_soft=None, + binary_loss_weight=None ): loss = 0 return_dict = {} @@ -772,7 +773,10 @@ class ForwardTTSLoss(nn.Module): if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss - return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + if binary_loss_weight: + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + else: + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss return_dict["loss"] = loss return return_dict diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 699f3142..bb8640a3 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram +from TTS.tts.utils.visual import 
plot_alignment, plot_avg_pitch, plot_spectrogram @dataclass @@ -186,7 +186,7 @@ class ForwardTTS(BaseTTS): self.max_duration = self.args.max_duration self.use_aligner = self.args.use_aligner self.use_pitch = self.args.use_pitch - self.use_binary_alignment_loss = False + self.binary_loss_weight = 0.0 self.length_scale = ( float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale @@ -644,8 +644,9 @@ class ForwardTTS(BaseTTS): pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, input_lens=text_lengths, alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, - alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, - alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None, + alignment_soft=outputs["alignment_soft"], + alignment_hard=outputs["alignment_mas"], + binary_loss_weight=self.binary_loss_weight ) # compute duration error durations_pred = outputs["durations"] @@ -672,17 +673,12 @@ class ForwardTTS(BaseTTS): # plot pitch figures if self.args.use_pitch: - pitch = batch["pitch"] - pitch_avg_expanded, _ = self.expand_encoder_outputs( - outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"] - ) - pitch = pitch[0, 0].data.cpu().numpy() - # TODO: denormalize before plotting - pitch = abs(pitch) - pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy() + pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) pitch_figures = { - "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False), - "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False), + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), } figures.update(pitch_figures) @@ -725,9 +721,8 @@ class ForwardTTS(BaseTTS): return ForwardTTSLoss(self.config) def on_train_step_start(self, trainer): - """Enable binary alignment loss when needed""" - if trainer.total_steps_done > self.config.binary_align_loss_start_step: - self.use_binary_alignment_loss = True + """Schedule binary loss weight.""" + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 @staticmethod def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None): From 1932401e8d11115efa8eee0a80fa1b265b17fca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 09:28:48 +0000 Subject: [PATCH 143/214] Fix dataset preprocessing --- TTS/tts/datasets/dataset.py | 67 ++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index a1bb23c3..62e146e0 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -1,4 +1,5 @@ import collections +from email.mime import audio import os import random from typing import Dict, List, Union @@ -140,8 +141,6 @@ class TTSDataset(Dataset): self.pitch_computed = False self.tokenizer = tokenizer - self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples) - if self.tokenizer.use_phonemes: self.phoneme_dataset = PhonemeDataset( self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers @@ -253,16 +252,14 @@ 
class TTSDataset(Dataset): return sample @staticmethod - def compute_lengths(samples): - audio_lengths = [] - text_lengths = [] + def _compute_lengths(samples): + new_samples = [] for item in samples: text, wav_file, *_ = _parse_sample(item) - audio_lengths.append(os.path.getsize(wav_file) / 16 * 8) # assuming 16bit audio - text_lengths.append(len(text)) - audio_lengths = np.array(audio_lengths) - text_lengths = np.array(text_lengths) - return audio_lengths, text_lengths + audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + text_lenght = len(text) + new_samples += [item + [audio_length, text_lenght]] + return new_samples @staticmethod def filter_by_length(lengths: List[int], min_len: int, max_len: int): @@ -278,8 +275,9 @@ class TTSDataset(Dataset): return ignore_idx, keep_idx @staticmethod - def sort_by_length(lengths: List[int]): - idxs = np.argsort(lengths) # ascending order + def sort_by_length(samples: List[List]): + audio_lengths = [s[-2] for s in samples] + idxs = np.argsort(audio_lengths) # ascending order return idxs @staticmethod @@ -293,39 +291,38 @@ class TTSDataset(Dataset): samples[offset:end_offset] = temp_items return samples - def select_samples_by_idx(self, idxs): - samples = [] - audio_lengths = [] - text_lengths = [] + def _select_samples_by_idx(self, idxs, samples): + samples_new = [] for idx in idxs: - samples.append(self.samples[idx]) - audio_lengths.append(self.audio_lengths[idx]) - text_lengths.append(self.text_lengths[idx]) - return samples, audio_lengths, text_lengths + samples_new.append(samples[idx]) + return samples_new def preprocess_samples(self): r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length range. """ + samples = self._compute_lengths(self.samples) # sort items based on the sequence length in ascending order - text_ignore_idx, text_keep_idx = self.filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len) + text_lengths = [i[-1] for i in samples] + audio_lengths = [i[-2] for i in samples] + text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) audio_ignore_idx, audio_keep_idx = self.filter_by_length( - self.audio_lengths, self.min_audio_len, self.max_audio_len + audio_lengths, self.min_audio_len, self.max_audio_len ) - keep_idx = list(set(audio_keep_idx) | set(text_keep_idx)) + keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) - samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx) + samples = self._select_samples_by_idx(keep_idx, samples) - sorted_idxs = self.sort_by_length(audio_lengths) + sorted_idxs = self.sort_by_length(samples) if self.start_by_longest: longest_idxs = sorted_idxs[-1] sorted_idxs[-1] = sorted_idxs[0] sorted_idxs[0] = longest_idxs - samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs) + samples = self._select_samples_by_idx(sorted_idxs, samples) if len(samples) == 0: raise RuntimeError(" [!] 
No samples left") @@ -337,19 +334,19 @@ class TTSDataset(Dataset): samples = self.create_buckets(samples, self.batch_group_size) # update items to the new sorted items - self.samples = samples - self.audio_lengths = audio_lengths - self.text_lengths = text_lengtsh + audio_lengths = [s[-2] for s in samples] + text_lengths = [s[-1] for s in samples] + self.samples = [s[:-2] for s in samples] if self.verbose: print(" | > Preprocessing samples") - print(" | > Max text length: {}".format(np.max(self.text_lengths))) - print(" | > Min text length: {}".format(np.min(self.text_lengths))) - print(" | > Avg text length: {}".format(np.mean(self.text_lengths))) + print(" | > Max text length: {}".format(np.max(text_lengths))) + print(" | > Min text length: {}".format(np.min(text_lengths))) + print(" | > Avg text length: {}".format(np.mean(text_lengths))) print(" | ") - print(" | > Max audio length: {}".format(np.max(self.audio_lengths))) - print(" | > Min audio length: {}".format(np.min(self.audio_lengths))) - print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths))) + print(" | > Max audio length: {}".format(np.max(audio_lengths))) + print(" | > Min audio length: {}".format(np.min(audio_lengths))) + print(" | > Avg audio length: {}".format(np.mean(audio_lengths))) print(f" | > Num. instances discarded samples: {len(ignore_idx)}") print(" | > Batch group size: {}.".format(self.batch_group_size)) From b3ed6ff6b74947dbcaaa41fae5060e18b84cc27d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 09:29:21 +0000 Subject: [PATCH 144/214] Update FastPitchConfig --- TTS/tts/configs/fast_pitch_config.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 8f063102..de870388 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -89,12 +89,9 @@ class FastPitchConfig(BaseTTSConfig): pitch_loss_alpha (float): Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. - binary_loss_alpha (float): + binary_align_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. - binary_align_loss_start_step (int): - Start binary alignment loss after this many steps. Defaults to 20000. - min_seq_len (int): Minimum input sequence length to be used at training. 
@@ -129,12 +126,12 @@ class FastPitchConfig(BaseTTSConfig): duration_loss_type: str = "mse" use_ssim_loss: bool = True ssim_loss_alpha: float = 1.0 - dur_loss_alpha: float = 1.0 spec_loss_alpha: float = 1.0 - pitch_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0 - binary_align_loss_alpha: float = 1.0 - binary_align_loss_start_step: int = 20000 + pitch_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 # overrides min_seq_len: int = 13 From 1f0c8179da4965dcf6f2048cdbcea710aef3875a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 10:40:29 +0000 Subject: [PATCH 145/214] Make style --- TTS/bin/find_unique_phonemes.py | 1 - TTS/config/shared_configs.py | 9 +++++++++ TTS/tts/datasets/dataset.py | 10 ++++------ TTS/tts/layers/losses.py | 6 ++++-- TTS/tts/models/forward_tts.py | 2 +- TTS/tts/models/vits.py | 10 ++++++---- TTS/tts/utils/helpers.py | 12 ++++++++---- TTS/tts/utils/languages.py | 3 ++- TTS/tts/utils/visual.py | 2 +- TTS/utils/synthesizer.py | 4 +--- TTS/vocoder/models/gan.py | 4 ++-- tests/data_tests/test_loader.py | 2 +- tests/tts_tests/test_helpers.py | 6 +++--- 13 files changed, 42 insertions(+), 29 deletions(-) diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index c5501552..8fe48b2f 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -9,7 +9,6 @@ from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut - phonemizer = Gruut(language="en-us") diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index 217282ad..392f10af 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -57,6 +57,12 @@ class BaseAudioConfig(Coqpit): do_amp_to_db_mel (bool, optional): enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + pitch_fmax (float, optional): + Maximum frequency of the F0 frames. Defaults to ```640```. + + pitch_fmin (float, optional): + Minimum frequency of the F0 frames. Defaults to ```0```. + trim_db (int): Silence threshold used for silence trimming. Defaults to 45. 
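With the new fields, F0 extraction is bounded independently of the mel filterbank; a minimal sketch, where the WAV path is a placeholder:

from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor

conf = BaseAudioConfig(pitch_fmin=0.0, pitch_fmax=640.0)
ap = AudioProcessor(**conf)
wav = ap.load_wav("sample.wav", sr=22050)  # placeholder path
pitch = ap.compute_f0(wav)  # pitch_fmax caps pyworld's f0_ceil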
@@ -135,6 +141,9 @@ class BaseAudioConfig(Coqpit): spec_gain: int = 20 do_amp_to_db_linear: bool = True do_amp_to_db_mel: bool = True + # f0 params + pitch_fmax: float = 640.0 + pitch_fmin: float = 0.0 # normalization params signal_norm: bool = True min_level_db: int = -100 diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 62e146e0..499e6b7b 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -1,5 +1,4 @@ import collections -from email.mime import audio import os import random from typing import Dict, List, Union @@ -256,7 +255,7 @@ class TTSDataset(Dataset): new_samples = [] for item in samples: text, wav_file, *_ = _parse_sample(item) - audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio text_lenght = len(text) new_samples += [item + [audio_length, text_lenght]] return new_samples @@ -291,7 +290,8 @@ class TTSDataset(Dataset): samples[offset:end_offset] = temp_items return samples - def _select_samples_by_idx(self, idxs, samples): + @staticmethod + def _select_samples_by_idx(idxs, samples): samples_new = [] for idx in idxs: samples_new.append(samples[idx]) @@ -307,9 +307,7 @@ class TTSDataset(Dataset): text_lengths = [i[-1] for i in samples] audio_lengths = [i[-2] for i in samples] text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) - audio_ignore_idx, audio_keep_idx = self.filter_by_length( - audio_lengths, self.min_audio_len, self.max_audio_len - ) + audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len) keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index f4a472ad..827da751 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -740,7 +740,7 @@ class ForwardTTSLoss(nn.Module): alignment_logprob=None, alignment_hard=None, alignment_soft=None, - binary_loss_weight=None + binary_loss_weight=None, ): loss = 0 return_dict = {} @@ -774,7 +774,9 @@ class ForwardTTSLoss(nn.Module): binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss if binary_loss_weight: - return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + return_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) else: return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index bb8640a3..8d554f76 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -646,7 +646,7 @@ class ForwardTTS(BaseTTS): alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, alignment_soft=outputs["alignment_soft"], alignment_hard=outputs["alignment_mas"], - binary_loss_weight=self.binary_loss_weight + binary_loss_weight=self.binary_loss_weight, ) # compute duration error durations_pred = outputs["durations"] diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 222bbca5..cb4499fb 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -364,7 +364,9 @@ class Vits(BaseTTS): ) upsample_rate = math.prod(self.args.upsample_rates_decoder) - 
assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}" + assert ( + upsample_rate == self.config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}" self.waveform_decoder = HifiganGenerator( self.args.hidden_channels, 1, @@ -666,7 +668,7 @@ class Vits(BaseTTS): waveform, slice_ids * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length, - pad_short=True + pad_short=True, ) if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: @@ -688,7 +690,7 @@ class Vits(BaseTTS): outputs.update( { "model_outputs": o, - "alignments" : attn.squeeze(1), + "alignments": attn.squeeze(1), "m_p": m_p, "logs_p": logs_p, "z": z, @@ -951,7 +953,7 @@ class Vits(BaseTTS): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - figures, audios = self._log(self.ap, batch, outputs, "eval") + figures, audios = self._log(self.ap, batch, outputs, "eval") logger.eval_figures(steps, figures) logger.eval_audios(steps, audios, self.ap.sample_rate) diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 32513377..c2e7f561 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -68,7 +68,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_ """ # pad the input tensor if it is shorter than the segment size if pad_short and x.shape[-1] < segment_size: - x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) segments = torch.zeros_like(x[:, :, :segment_size]) @@ -78,12 +78,14 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_ x_i = x[i] if pad_short and index_end > x.size(2): # pad the sample if it is shorter than the segment size - x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) segments[i] = x_i[:, index_start:index_end] return segments -def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False): +def rand_segments( + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False +): """Create random segments based on the input lengths. Args: @@ -110,7 +112,9 @@ def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size= _x_lenghts[len_diff < 0] = segment_size len_diff = _x_lenghts - segment_size + 1 else: - assert all(len_diff > 0), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + assert all( + len_diff > 0 + ), f" [!] At least one sample is shorter than the segment size ({segment_size}). 
\n {_x_lenghts}" segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() ret = segment(x, segment_indices, segment_size) return ret, segment_indices diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 54ba40b2..19708c13 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -1,7 +1,6 @@ import json import os from typing import Dict, List -from TTS.config import check_config_and_model_args import fsspec import numpy as np @@ -9,6 +8,8 @@ import torch from coqpit import Coqpit from torch.utils.data.sampler import WeightedRandomSampler +from TTS.config import check_config_and_model_args + class LanguageManager: """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 4fd1f19c..78c12981 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -104,7 +104,7 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False): fig, ax = plt.subplots() x = np.array(range(len(chars))) - my_xticks = [c for c in chars] + my_xticks = chars plt.xticks(x, my_xticks) ax.set_xlabel("characters") diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index ddc2a6a5..6821e975 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -5,10 +5,8 @@ import numpy as np import pysbd import torch -from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config +from TTS.config import load_config from TTS.tts.models import setup_model as setup_tts_model -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 7e03e94f..6978f0e7 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results class GAN(BaseVocoder): - def __init__(self, config: Coqpit, ap: AudioProcessor=None): + def __init__(self, config: Coqpit, ap: AudioProcessor = None): """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. It also helps mixing and matching different generator and disciminator networks easily. 
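After this change the audio processor travels with the model instead of the `assets` dict; a construction sketch, with the config class picked only as an example:

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import HifiganConfig
from TTS.vocoder.models.gan import GAN

config = HifiganConfig()
# init_from_config builds the AudioProcessor from the config and injects it
model = GAN.init_from_config(config)
assert isinstance(model.ap, AudioProcessor)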
@@ -306,7 +306,7 @@ class GAN(BaseVocoder): x, y = batch return {"input": x, "waveform": y} - def get_data_loader( # pylint: disable=no-self-use + def get_data_loader( # pylint: disable=no-self-use, unused-argument self, config: Coqpit, assets: Dict, diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index f96154bc..4d8cc68a 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -63,7 +63,7 @@ class TestTTSDataset(unittest.TestCase): max_text_len=c.max_text_len, min_audio_len=c.min_audio_len, max_audio_len=c.max_audio_len, - start_by_longest=start_by_longest + start_by_longest=start_by_longest, ) dataloader = DataLoader( dataset, diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py index 708ecbf5..23bb440a 100644 --- a/tests/tts_tests/test_helpers.py +++ b/tests/tts_tests/test_helpers.py @@ -1,6 +1,6 @@ import torch as T -from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments +from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask def average_over_durations_test(): # pylint: disable=no-self-use @@ -57,12 +57,12 @@ def rand_segments_test(): assert segments.shape == (2, 3, 3) assert all(seg_idxs >= 0), seg_idxs try: - segments, _ = rand_segments(x, x_lens, segment_size=5) + segments, _ = rand_segments(x, x_lens, segment_size=5) raise Exception("Should have failed") except: pass x_lens_back = x_lens.clone() - segments, seg_idxs= rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True) + segments, seg_idxs = rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True) assert segments.shape == (2, 3, 5) assert all(seg_idxs >= 0), seg_idxs assert all(x_lens_back == x_lens) From ec4b03c0455b02c2d7f1b1cd0ff871a1ae5cf212 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 10:41:20 +0000 Subject: [PATCH 146/214] Update AnalyzeDataset notebook --- .../dataset_analysis/AnalyzeDataset.ipynb | 76 ++++++++++++------- 1 file changed, 50 insertions(+), 26 deletions(-) diff --git a/notebooks/dataset_analysis/AnalyzeDataset.ipynb b/notebooks/dataset_analysis/AnalyzeDataset.ipynb index c2aabbf9..e08f3ab3 100644 --- a/notebooks/dataset_analysis/AnalyzeDataset.ipynb +++ b/notebooks/dataset_analysis/AnalyzeDataset.ipynb @@ -8,7 +8,7 @@ }, "outputs": [], "source": [ - "TTS_PATH = \"/home/erogol/projects/\"" + "# TTS_PATH = \"/home/erogol/projects/\"" ] }, { @@ -21,7 +21,6 @@ "source": [ "import os\n", "import sys\n", - "sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n", "import librosa\n", "import numpy as np\n", "import pandas as pd\n", @@ -30,6 +29,8 @@ "from multiprocessing import Pool\n", "from matplotlib import pylab as plt\n", "from collections import Counter\n", + "from TTS.config.shared_configs import BaseDatasetConfig\n", + "from TTS.tts.datasets import load_tts_samples\n", "from TTS.tts.datasets.formatters import *\n", "%matplotlib inline" ] @@ -42,22 +43,29 @@ }, "outputs": [], "source": [ - "DATA_PATH = \"/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/\"\n", - "META_DATA = [\"kleinzaches/metadata.csv\",\n", - " \"spiegel_kaetzchen/metadata.csv\",\n", - " \"herrnarnesschatz/metadata.csv\",\n", - " \"maedchen_von_moorhof/metadata.csv\",\n", - " \"koenigsgaukler/metadata.csv\",\n", - " \"altehous/metadata.csv\",\n", - " \"odysseus/metadata.csv\",\n", - " \"undine/metadata.csv\",\n", - " 
\"reise_tilsit/metadata.csv\",\n", - " \"schmied_seines_glueckes/metadata.csv\",\n", - " \"kammmacher/metadata.csv\",\n", - " \"unterm_birnbaum/metadata.csv\",\n", - " \"liebesbriefe/metadata.csv\",\n", - " \"sandmann/metadata.csv\"]\n", - "NUM_PROC = 8" + "NUM_PROC = 8\n", + "DATASET_CONFIG = BaseDatasetConfig(\n", + " name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=\"/home/ubuntu/TTS/depot/data/male_dataset1_44k/\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def formatter(root_path, meta_file, **kwargs): # pylint: disable=unused-argument\n", + " txt_file = os.path.join(root_path, meta_file)\n", + " items = []\n", + " speaker_name = \"maledataset1\"\n", + " with open(txt_file, \"r\", encoding=\"utf-8\") as ttf:\n", + " for line in ttf:\n", + " cols = line.split(\"|\")\n", + " wav_file = os.path.join(root_path, \"wavs\", cols[0])\n", + " text = cols[1]\n", + " items.append([text, wav_file, speaker_name])\n", + " return items" ] }, { @@ -69,8 +77,10 @@ "outputs": [], "source": [ "# use your own preprocessor at this stage - TTS/datasets/proprocess.py\n", - "items = mailabs(DATA_PATH, META_DATA)\n", - "print(\" > Number of audio files: {}\".format(len(items)))" + "train_samples, eval_samples = load_tts_samples(DATASET_CONFIG, eval_split=True, formatter=formatter)\n", + "items = train_samples + eval_samples\n", + "print(\" > Number of audio files: {}\".format(len(items)))\n", + "print(items[1])" ] }, { @@ -103,6 +113,15 @@ "print([item for item, count in c.items() if count > 1])" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "item" + ] + }, { "cell_type": "code", "execution_count": null, @@ -112,11 +131,9 @@ "outputs": [], "source": [ "def load_item(item):\n", - " file_name = item[1].strip()\n", " text = item[0].strip()\n", - " audio = librosa.load(file_name, sr=None)\n", - " sr = audio[1]\n", - " audio = audio[0]\n", + " file_name = item[1].strip()\n", + " audio, sr = librosa.load(file_name, sr=None)\n", " audio_len = len(audio) / sr\n", " text_len = len(text)\n", " return file_name, text, text_len, audio, audio_len\n", @@ -374,11 +391,18 @@ "# fequency bar plot - it takes time!!\n", "w_count_df.plot.bar()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -392,7 +416,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.9.1" } }, "nbformat": 4, From d5c0e17548e0a846dc0fb77653965274ae58d880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 28 Jan 2022 10:20:07 +0100 Subject: [PATCH 147/214] Load right char class dynamically --- TTS/tts/utils/text/tokenizer.py | 21 ++++++++++++++++----- TTS/utils/generic_utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 80be368d..bdaf8ea6 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Union from TTS.tts.utils.text import cleaners from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +from 
From d5c0e17548e0a846dc0fb77653965274ae58d880 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 28 Jan 2022 10:20:07 +0100
Subject: [PATCH 147/214] Load right char class dynamically

---
 TTS/tts/utils/text/tokenizer.py | 21 ++++++++++++++++-----
 TTS/utils/generic_utils.py      | 27 +++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py
index 80be368d..bdaf8ea6 100644
--- a/TTS/tts/utils/text/tokenizer.py
+++ b/TTS/tts/utils/text/tokenizer.py
@@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Union
 from TTS.tts.utils.text import cleaners
 from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
 from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
+from TTS.utils.generic_utils import get_import_path, import_class
 
 
 class TTSTokenizer:
@@ -152,15 +153,25 @@ class TTSTokenizer:
 
         # init characters
         if characters is None:
-            if config.use_phonemes:
-                # init phoneme set
-                characters, new_config = IPAPhonemes().init_from_config(config)
+            # set characters based on defined characters class
+            if config.characters and config.characters.characters_class:
+                CharactersClass = import_class(config.characters.characters_class)
+                characters, new_config = CharactersClass.init_from_config(config)
+            # set characters based on config
             else:
-                # init character set
-                characters, new_config = Graphemes().init_from_config(config)
+                if config.use_phonemes:
+                    # init phoneme set
+                    characters, new_config = IPAPhonemes().init_from_config(config)
+                else:
+                    # init character set
+                    characters, new_config = Graphemes().init_from_config(config)
+
         else:
             characters, new_config = characters.init_from_config(config)
 
+        # set characters class
+        new_config.characters.characters_class = get_import_path(characters)
+
         # init phonemizer
         phonemizer = None
         if config.use_phonemes:
diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py
index 6504cca6..69609bcb 100644
--- a/TTS/utils/generic_utils.py
+++ b/TTS/utils/generic_utils.py
@@ -95,6 +95,33 @@ def find_module(module_path: str, module_name: str) -> object:
     return getattr(module, class_name)
 
 
+def import_class(module_path: str) -> object:
+    """Import a class from a module path.
+
+    Args:
+        module_path (str): The module path of the class.
+
+    Returns:
+        object: The imported class.
+    """
+    class_name = module_path.split(".")[-1]
+    module_path = ".".join(module_path.split(".")[:-1])
+    module = importlib.import_module(module_path)
+    return getattr(module, class_name)
+
+
+def get_import_path(obj: object) -> str:
+    """Get the import path of a class.
+
+    Args:
+        obj (object): The class object.
+
+    Returns:
+        str: The import path of the class.
+    """
+    return ".".join([type(obj).__module__, type(obj).__name__])
+
+
 def get_user_data_dir(appname):
     if sys.platform == "win32":
         import winreg  # pylint: disable=import-outside-toplevel
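The two helpers added above are deliberately symmetric, which is what lets the tokenizer persist the characters class as a dotted string in `config.characters.characters_class` and re-import it on the next load. A quick round-trip sketch using the `Graphemes` class the tokenizer already imports:

    from TTS.tts.utils.text.characters import Graphemes
    from TTS.utils.generic_utils import get_import_path, import_class

    # instance -> dotted import path, as stored in the config
    path = get_import_path(Graphemes())
    assert path == "TTS.tts.utils.text.characters.Graphemes"

    # dotted import path -> class object, as resolved when the config is loaded
    CharactersClass = import_class(path)
    assert CharactersClass is Graphemes
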
""" + characters_class: str = None pad: str = None eos: str = None bos: str = None diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index cb4499fb..a69b02ba 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -649,12 +649,7 @@ class Vits(BaseTTS): z_p = self.flow(z, y_mask, g=g) # duration predictor - if self.args.use_mas: - outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) - elif self.args.use_aligner_network: - outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb) - outputs["x_lens"] = x_lengths - outputs["y_lens"] = y_lengths + outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) # expand prior m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) @@ -1059,7 +1054,15 @@ class Vits(BaseTTS): # TODO: consider baking the speaker encoder into the model and call it from there. # as it is probably easier for model distribution. state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} - self.load_state_dict(state["model"]) + # handle fine-tuning from a checkpoint with additional speakers + if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: + print(" > Loading checkpoint with additional speakers.") + emb_g = state["model"]["emb_g.weight"] + new_row = torch.zeros(1, emb_g.shape[1]) + emb_g = torch.cat([emb_g, new_row], axis=0) + state["model"]["emb_g.weight"] = emb_g + + self.load_state_dict(state["model"], strict=False) if eval: self.eval() assert not self.training From 846e0e4284d1a00dee25c5ed85c74830a95ec28e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 28 Jan 2022 10:23:52 +0100 Subject: [PATCH 149/214] Fix VCTK VITS recipe --- recipes/vctk/vits/train_vits.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py index 2906557d..caf1caa1 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -57,9 +57,7 @@ config = VitsConfig( print_step=25, print_eval=False, mixed_precision=True, - sort_by_audio_len=True, - min_seq_len=32 * 256 * 4, - max_seq_len=1500000, + max_text_len= 325, # change this if you have a larger VRAM than 16GB output_path=output_path, datasets=[dataset_config], ) From 38314194e7ec1436987d57c8865e1af4fdab6dea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 28 Jan 2022 13:50:58 +0100 Subject: [PATCH 150/214] Set `drop_last` --- TTS/tts/models/base_tts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 7cdfa915..0eb2b5f3 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -324,9 +324,9 @@ class BaseTTS(BaseModel): loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, - shuffle=False, + shuffle=False, # shuffle is done in the dataset. collate_fn=dataset.collate_fn, - drop_last=False, + drop_last=True, # setting this False might cause issues in AMP training. 
sampler=sampler, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, From a013566d1516c2f6820f101ac066f1278122af0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:35:52 +0100 Subject: [PATCH 151/214] Delete trainer related code --- TTS/trainer.py | 1199 ----------------------- TTS/utils/logging/__init__.py | 24 - TTS/utils/logging/console_logger.py | 105 -- TTS/utils/logging/tensorboard_logger.py | 79 -- TTS/utils/logging/wandb_logger.py | 111 --- 5 files changed, 1518 deletions(-) delete mode 100644 TTS/trainer.py delete mode 100644 TTS/utils/logging/__init__.py delete mode 100644 TTS/utils/logging/console_logger.py delete mode 100644 TTS/utils/logging/tensorboard_logger.py delete mode 100644 TTS/utils/logging/wandb_logger.py diff --git a/TTS/trainer.py b/TTS/trainer.py deleted file mode 100644 index 7bffb386..00000000 --- a/TTS/trainer.py +++ /dev/null @@ -1,1199 +0,0 @@ -# -*- coding: utf-8 -*- - -import importlib -import multiprocessing -import os -import platform -import sys -import time -import traceback -from argparse import Namespace -from dataclasses import dataclass, field -from inspect import signature -from typing import Callable, Dict, List, Tuple, Union - -import torch -import torch.distributed as dist -from coqpit import Coqpit -from torch import nn -from torch.nn.parallel import DistributedDataParallel as DDP_th -from torch.utils.data import DataLoader - -from TTS.utils.callbacks import TrainerCallback -from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import ( - KeepAverage, - count_parameters, - get_experiment_folder_path, - get_git_branch, - remove_experiment_folder, - set_init_dict, - to_cuda, -) -from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint -from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger -from TTS.utils.trainer_utils import ( - get_last_checkpoint, - get_optimizer, - get_scheduler, - is_apex_available, - setup_torch_training_env, -) - -multiprocessing.set_start_method("fork") - -if platform.system() != "Windows": - # https://github.com/pytorch/pytorch/issues/973 - import resource - - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) - resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) - - -if is_apex_available(): - from apex import amp - - -@dataclass -class TrainingArgs(Coqpit): - """Trainer arguments to be defined externally. It helps integrating the `Trainer` with the higher level APIs and - set the values for distributed training.""" - - continue_path: str = field( - default="", - metadata={ - "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder." - }, - ) - restore_path: str = field( - default="", - metadata={ - "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training." - }, - ) - best_path: str = field( - default="", - metadata={ - "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used" - }, - ) - skip_train_epoch: bool = field( - default=False, metadata={"help": "Run only evaluation iteration. 
Useful for debugging."} - ) - config_path: str = field(default="", metadata={"help": "Path to the configuration file."}) - rank: int = field(default=0, metadata={"help": "Process rank in distributed training."}) - group_id: str = field(default="", metadata={"help": "Process group id in distributed training."}) - use_ddp: bool = field( - default=False, - metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."}, - ) - - -class Trainer: - def __init__( # pylint: disable=dangerous-default-value - self, - args: Union[Coqpit, Namespace], - config: Coqpit, - output_path: str, - c_logger: ConsoleLogger = None, - dashboard_logger: Union[TensorboardLogger, WandbLogger] = None, - model: nn.Module = None, - get_model: Callable = None, - get_data_samples: Callable = None, - train_samples: List = None, - eval_samples: List = None, - cudnn_benchmark: bool = False, - training_assets: Dict = {}, - parse_command_line_args: bool = True, - ) -> None: - """Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models - or easily be customized. - - Notes: - - Supports Automatic Mixed Precision training. If `Apex` is availabe, it automatically picks that, else - it uses PyTorch's native `amp` module. `Apex` may provide more stable training in some cases. - - Args: - - args (Union[Coqpit, Namespace]): Training arguments parsed either from console by `argparse` or `TrainingArgs` - config object. - - config (Coqpit): Model config object. It includes all the values necessary for initializing, training, evaluating - and testing the model. - - output_path (str): Path to the output training folder. All the files are saved under thi path. - - c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default - console logger is used. Defaults to None. - - dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used. - Defaults to None. - - model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` - initializes a model from the provided config. Defaults to None. - - get_model (Callable): - A function that returns a model. It is used to initialize the model when `model` is not provided. - It either takes the config as the only argument or does not take any argument. - Defaults to None - - get_data_samples (Callable): - A function that returns a list of training and evaluation samples. Used if `train_samples` and - `eval_samples` are None. Defaults to None. - - train_samples (List): - A list of training samples used by the model's `get_data_loader` to init the `dataset` and the - `data_loader`. Defaults to None. - - eval_samples (List): - A list of evaluation samples used by the model's `get_data_loader` to init the `dataset` and the - `data_loader`. Defaults to None. - - cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable if the model input - length is changing batch to batch along the training. - - training_assets (Dict): - A dictionary of assets to be used at training and passed to the model's ```train_log(), eval_log(), get_data_loader()``` - during training. It can include `AudioProcessor` or/and `Tokenizer`. Defaults to {}. - - parse_command_line_args (bool): - If true, parse command-line arguments and update `TrainingArgs` and model `config` values. Set it - to false if you parse the arguments yourself. Defaults to True. 
- - Examples: - - Running trainer with HifiGAN model. - - >>> args = TrainingArgs(...) - >>> config = HifiganConfig(...) - >>> model = GANModel(config) - >>> ap = AudioProcessor(**config.audio) - >>> assets = {"audio_processor": ap} - >>> trainer = Trainer(args, config, output_path, model=model, training_assets=assets) - >>> trainer.fit() - - TODO: - - Wrap model for not calling .module in DDP. - - Accumulate gradients b/w batches. - - Deepspeed integration - - Profiler integration. - - Overfitting to a batch. - - TPU training - - NOTE: Consider moving `training_assets` to the model implementation. - """ - - if parse_command_line_args: - # parse command-line arguments for TrainerArgs() - args, coqpit_overrides = self.parse_argv(args) - - # get ready for training and parse command-line arguments for the model config - config = self.init_training(args, coqpit_overrides, config) - - # set the output path - if args.continue_path: - # use the same path as the continuing run - output_path = args.continue_path - else: - # override the output path if it is provided - output_path = config.output_path if output_path is None else output_path - # create a new output folder name - output_path = get_experiment_folder_path(config.output_path, config.run_name) - os.makedirs(output_path, exist_ok=True) - - # copy training assets to the output folder - copy_model_files(config, output_path) - - # init class members - self.args = args - self.config = config - self.output_path = output_path - self.config.output_log_path = output_path - self.training_assets = training_assets - - # setup logging - log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") - self._setup_logger_config(log_file) - time.sleep(1.0) # wait for the logger to be ready - - # set and initialize Pytorch runtime - self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp) - - # init loggers - self.c_logger = ConsoleLogger() if c_logger is None else c_logger - self.dashboard_logger = dashboard_logger - - # only allow dashboard logging for the main process in DDP mode - if self.dashboard_logger is None and args.rank == 0: - self.dashboard_logger = init_dashboard_logger(config) - - if not self.config.log_model_step: - self.config.log_model_step = self.config.save_step - - self.total_steps_done = 0 - self.epochs_done = 0 - self.restore_step = 0 - self.best_loss = float("inf") - self.train_loader = None - self.eval_loader = None - - self.keep_avg_train = None - self.keep_avg_eval = None - - self.use_apex = self._is_apex_available() - self.use_amp_scaler = self.config.mixed_precision and self.use_cuda - - # load data samples - if train_samples is None and get_data_samples is None: - raise ValueError("[!] `train_samples` and `get_data_samples` cannot both be None.") - if train_samples is not None: - self.train_samples = train_samples - self.eval_samples = eval_samples - else: - self.train_samples, self.eval_samples = self.run_get_data_samples(config, get_data_samples) - - # init TTS model - if model is None and get_model is None: - raise ValueError("[!] 
`model` and `get_model` cannot both be None.") - if model is not None: - self.model = model - else: - self.run_get_model(self.config, get_model) - - # setup criterion - self.criterion = self.get_criterion(self.model) - - # DISTRUBUTED - if self.num_gpus > 1: - init_distributed( - args.rank, - self.num_gpus, - args.group_id, - self.config.distributed_backend, - self.config.distributed_url, - ) - - if self.use_cuda: - self.model.cuda() - if isinstance(self.criterion, list): - self.criterion = [x.cuda() for x in self.criterion] - else: - self.criterion.cuda() - - # setup optimizer - self.optimizer = self.get_optimizer(self.model, self.config) - - # CALLBACK - self.callbacks = TrainerCallback() - self.callbacks.on_init_start(self) - - # init AMP - if self.use_amp_scaler: - if self.use_apex: - self.scaler = None - self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1") - # if isinstance(self.optimizer, list): - # self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer) - # else: - self.scaler = torch.cuda.amp.GradScaler() - else: - self.scaler = None - - if self.args.restore_path: - self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model( - self.config, args.restore_path, self.model, self.optimizer, self.scaler - ) - - # setup scheduler - self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer) - - if self.scheduler is not None: - if self.args.continue_path: - if isinstance(self.scheduler, list): - for scheduler in self.scheduler: - if scheduler is not None: - scheduler.last_epoch = self.restore_step - else: - self.scheduler.last_epoch = self.restore_step - - # DISTRIBUTED - if self.num_gpus > 1: - self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) - - # count model size - num_params = count_parameters(self.model) - print("\n > Model has {} parameters".format(num_params)) - - self.callbacks.on_init_end(self) - - @staticmethod - def parse_argv(args: Union[Coqpit, List]): - """Parse command line arguments to init or override `TrainingArgs()`.""" - if isinstance(args, Coqpit): - parser = args.init_argparse(arg_prefix="") - else: - train_config = TrainingArgs() - parser = train_config.init_argparse(arg_prefix="") - training_args, coqpit_overrides = parser.parse_known_args() - args.parse_args(training_args) - return args, coqpit_overrides - - def init_training( - self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None - ): # pylint: disable=no-self-use - """Initialize training and update model configs from command line arguments. - - Args: - args (argparse.Namespace or dict like): Parsed input arguments. - config_overrides (argparse.Namespace or dict like): Parsed config overriding arguments. - config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. - - Returns: - c (TTS.utils.io.AttrDict): Config paramaters. 
- """ - # set arguments for continuing training - if args.continue_path: - experiment_path = args.continue_path - args.config_path = os.path.join(args.continue_path, "config.json") - args.restore_path, best_model = get_last_checkpoint(args.continue_path) - if not args.best_path: - args.best_path = best_model - - # override config values from command-line args - # TODO: Maybe it is better to do it outside - if len(coqpit_overrides) > 0: - config.parse_known_args(coqpit_overrides, arg_prefix="coqpit", relaxed_parser=True) - experiment_path = args.continue_path - - # update the config.json fields and copy it to the output folder - if args.rank == 0: - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - copy_model_files(config, experiment_path, new_fields) - return config - - @staticmethod - def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module: - """Run the `get_model` function and return the model. - - Args: - config (Coqpit): Model config. - - Returns: - nn.Module: initialized model. - """ - if len(signature(get_model).sig.parameters) == 1: - model = get_model(config) - else: - model = get_model() - return model - - @staticmethod - def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module: - if callable(get_data_samples): - if len(signature(get_data_samples).sig.parameters) == 1: - train_samples, eval_samples = get_data_samples(config) - else: - train_samples, eval_samples = get_data_samples() - return train_samples, eval_samples - return None, None - - def restore_model( - self, - config: Coqpit, - restore_path: str, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scaler: torch.cuda.amp.GradScaler = None, - ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: - """Restore training from an old run. It restores model, optimizer, AMP scaler and training stats. - - Args: - config (Coqpit): Model config. - restore_path (str): Path to the restored training run. - model (nn.Module): Model to restored. - optimizer (torch.optim.Optimizer): Optimizer to restore. - scaler (torch.cuda.amp.GradScaler, optional): AMP scaler to restore. Defaults to None. - - Returns: - Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: [description] - """ - - def _restore_list_objs(states, obj): - if isinstance(obj, list): - for idx, state in enumerate(states): - obj[idx].load_state_dict(state) - else: - obj.load_state_dict(states) - return obj - - print(" > Restoring from %s ..." 
% os.path.basename(restore_path)) - checkpoint = load_fsspec(restore_path, map_location="cpu") - try: - print(" > Restoring Model...") - model.load_state_dict(checkpoint["model"]) - print(" > Restoring Optimizer...") - optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer) - if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]: - print(" > Restoring Scaler...") - scaler = _restore_list_objs(checkpoint["scaler"], scaler) - except (KeyError, RuntimeError, ValueError): - print(" > Partial model initialization...") - model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], config) - model.load_state_dict(model_dict) - del model_dict - - if isinstance(self.optimizer, list): - for idx, optim in enumerate(optimizer): - for group in optim.param_groups: - group["lr"] = self.get_lr(model, config)[idx] - else: - for group in optimizer.param_groups: - group["lr"] = self.get_lr(model, config) - print( - " > Model restored from step %d" % checkpoint["step"], - ) - restore_step = checkpoint["step"] - torch.cuda.empty_cache() - return model, optimizer, scaler, restore_step - - ######################### - # DATA LOADING FUNCTIONS - ######################### - - def _get_loader( - self, - model: nn.Module, - config: Coqpit, - assets: Dict, - is_eval: bool, - data_items: List, - verbose: bool, - num_gpus: int, - ) -> DataLoader: - if num_gpus > 1: - if hasattr(model.module, "get_data_loader"): - loader = model.module.get_data_loader( - config, assets, is_eval, data_items, verbose, num_gpus, self.args.rank - ) - else: - if hasattr(model, "get_data_loader"): - loader = model.get_data_loader(config, assets, is_eval, data_items, verbose, num_gpus) - return loader - - def get_train_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader: - """Initialize and return a training data loader. - - Args: - ap (AudioProcessor): Audio processor. - data_items (List): Data samples used for training. - verbose (bool): enable/disable printing loader stats at initialization. - - Returns: - DataLoader: Initialized training data loader. - """ - return self._get_loader(self.model, self.config, training_assets, False, data_items, verbose, self.num_gpus) - - def get_eval_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader: - return self._get_loader(self.model, self.config, training_assets, True, data_items, verbose, self.num_gpus) - - def format_batch(self, batch: List) -> Dict: - """Format the dataloader output and return a batch. - - Args: - batch (List): Batch returned by the dataloader. - - Returns: - Dict: Formatted batch. - """ - if self.num_gpus > 1: - batch = self.model.module.format_batch(batch) - else: - batch = self.model.format_batch(batch) - if self.use_cuda: - for k, v in batch.items(): - batch[k] = to_cuda(v) - return batch - - ###################### - # TRAIN FUNCTIONS - ###################### - - @staticmethod - def master_params(optimizer: torch.optim.Optimizer): - """Generator over parameters owned by the optimizer. - - Used to select parameters used by the optimizer for gradient clipping. - - Args: - optimizer: Target optimizer. - """ - for group in optimizer.param_groups: - for p in group["params"]: - yield p - - @staticmethod - def _model_train_step( - batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None - ) -> Tuple[Dict, Dict]: - """ - Perform a trainig forward step. Compute model outputs and losses. 
- - Args: - batch (Dict): [description] - model (nn.Module): [description] - criterion (nn.Module): [description] - optimizer_idx (int, optional): [description]. Defaults to None. - - Returns: - Tuple[Dict, Dict]: [description] - """ - input_args = [batch, criterion] - if optimizer_idx is not None: - input_args.append(optimizer_idx) - # unwrap model in DDP training - if hasattr(model, "module"): - return model.module.train_step(*input_args) - return model.train_step(*input_args) - - def _optimize( - self, - batch: Dict, - model: nn.Module, - optimizer: Union[torch.optim.Optimizer, List], - scaler: "AMPScaler", - criterion: nn.Module, - scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access - config: Coqpit, - optimizer_idx: int = None, - ) -> Tuple[Dict, Dict, int]: - """Perform a forward - backward pass and run the optimizer. - - Args: - batch (Dict): Input batch. If - model (nn.Module): Model for training. Defaults to None. - optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use. - scaler (AMPScaler): AMP scaler. - criterion (nn.Module): Model's criterion. - scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer. - config (Coqpit): Model config. - optimizer_idx (int, optional): Target optimizer being used. Defaults to None. - - Raises: - RuntimeError: When the loss is NaN. - - Returns: - Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm. - """ - - step_start_time = time.time() - # zero-out optimizer - optimizer.zero_grad() - - # forward pass and loss computation - with torch.cuda.amp.autocast(enabled=config.mixed_precision): - if optimizer_idx is not None: - outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx) - else: - outputs, loss_dict = self._model_train_step(batch, model, criterion) - - # skip the rest - if outputs is None: - step_time = time.time() - step_start_time - return None, {}, step_time - - # # check nan loss - # if torch.isnan(loss_dict["loss"]).any(): - # raise RuntimeError(f" > NaN loss detected - {loss_dict}") - - # set gradient clipping threshold - if "grad_clip" in config and config.grad_clip is not None: - if optimizer_idx is not None: - grad_clip = config.grad_clip[optimizer_idx] - else: - grad_clip = config.grad_clip - else: - grad_clip = 0.0 # meaning no gradient clipping - - # optimizer step - grad_norm = 0 - update_lr_scheduler = True - if self.use_amp_scaler: - if self.use_apex: - # TODO: verify AMP use for GAN training in TTS - # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers - with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss: - scaled_loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), grad_clip) - else: - # model optimizer step in mixed precision mode - scaler.scale(loss_dict["loss"]).backward() - if grad_clip > 0: - scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip) - scale_prev = scaler.get_scale() - scaler.step(optimizer) - scaler.update() - update_lr_scheduler = scale_prev <= scaler.get_scale() - loss_dict["amp_scaler"] = scaler.get_scale() # for logging - else: - # main model optimizer step - loss_dict["loss"].backward() - if grad_clip > 0: - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) - 
optimizer.step() - - # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN - if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)): - grad_norm = 0 - - step_time = time.time() - step_start_time - - # setup lr - if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch: - scheduler.step() - - # detach losses - loss_dict = self._detach_loss_dict(loss_dict) - if optimizer_idx is not None: - loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss") - loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm - else: - loss_dict["grad_norm"] = grad_norm - return outputs, loss_dict, step_time - - def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]: - """Perform a training step on a batch of inputs and log the process. - - Args: - batch (Dict): Input batch. - batch_n_steps (int): Number of steps needed to complete an epoch. Needed for logging. - step (int): Current step number in this epoch. - loader_start_time (float): The time when the data loading is started. Needed for logging. - - Returns: - Tuple[Dict, Dict]: Model outputs and losses. - """ - self.callbacks.on_train_step_start(self) - # format data - batch = self.format_batch(batch) - loader_time = time.time() - loader_start_time - - # conteainers to hold model outputs and losses for each optimizer. - outputs_per_optimizer = None - loss_dict = {} - if not isinstance(self.optimizer, list): - # training with a single optimizer - outputs, loss_dict_new, step_time = self._optimize( - batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config - ) - loss_dict.update(loss_dict_new) - else: - # training with multiple optimizers (e.g. 
GAN) - outputs_per_optimizer = [None] * len(self.optimizer) - total_step_time = 0 - for idx, optimizer in enumerate(self.optimizer): - criterion = self.criterion - # scaler = self.scaler[idx] if self.use_amp_scaler else None - scaler = self.scaler - scheduler = self.scheduler[idx] - outputs, loss_dict_new, step_time = self._optimize( - batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx - ) - # skip the rest if the model returns None - total_step_time += step_time - outputs_per_optimizer[idx] = outputs - # merge loss_dicts from each optimizer - # rename duplicates with the optimizer idx - # if None, model skipped this optimizer - if loss_dict_new is not None: - for k, v in loss_dict_new.items(): - if k in loss_dict: - loss_dict[f"{k}-{idx}"] = v - else: - loss_dict[k] = v - step_time = total_step_time - outputs = outputs_per_optimizer - - # update avg runtime stats - keep_avg_update = {} - keep_avg_update["avg_loader_time"] = loader_time - keep_avg_update["avg_step_time"] = step_time - self.keep_avg_train.update_values(keep_avg_update) - - # update avg loss stats - update_eval_values = {} - for key, value in loss_dict.items(): - update_eval_values["avg_" + key] = value - self.keep_avg_train.update_values(update_eval_values) - - # print training progress - if self.total_steps_done % self.config.print_step == 0: - # log learning rates - lrs = {} - if isinstance(self.optimizer, list): - for idx, optimizer in enumerate(self.optimizer): - current_lr = self.optimizer[idx].param_groups[0]["lr"] - lrs.update({f"current_lr_{idx}": current_lr}) - else: - current_lr = self.optimizer.param_groups[0]["lr"] - lrs = {"current_lr": current_lr} - - # log run-time stats - loss_dict.update(lrs) - loss_dict.update( - { - "step_time": round(step_time, 4), - "loader_time": round(loader_time, 4), - } - ) - self.c_logger.print_train_step( - batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values - ) - - if self.args.rank == 0: - # Plot Training Iter Stats - # reduce TB load and don't log every step - if self.total_steps_done % self.config.plot_step == 0: - self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict) - if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0: - if self.config.checkpoint: - # checkpoint the model - target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train) - save_checkpoint( - self.config, - self.model, - self.optimizer, - self.scaler if self.use_amp_scaler else None, - self.total_steps_done, - self.epochs_done, - self.output_path, - model_loss=target_avg_loss, - ) - - if self.total_steps_done % self.config.log_model_step == 0: - # log checkpoint as artifact - aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"] - self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases) - - # training visualizations - if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"): - self.model.module.train_log( - batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done - ) - elif hasattr(self.model, "train_log"): - self.model.train_log( - batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done - ) - - self.dashboard_logger.flush() - - self.total_steps_done += 1 - self.callbacks.on_train_step_end(self) - return outputs, loss_dict - - def train_epoch(self) -> None: - """Main entry point for the training loop. 
Run training on the all training samples.""" - # initialize the data loader - self.train_loader = self.get_train_dataloader( - self.training_assets, - self.train_samples, - verbose=True, - ) - # set model to training mode - if self.num_gpus > 1: - self.model.module.train() - else: - self.model.train() - epoch_start_time = time.time() - if self.use_cuda: - batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus)) - else: - batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size) - self.c_logger.print_train_start() - loader_start_time = time.time() - # iterate over the training samples - for cur_step, batch in enumerate(self.train_loader): - _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time) - loader_start_time = time.time() - epoch_time = time.time() - epoch_start_time - # plot self.epochs_done Stats - if self.args.rank == 0: - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(self.keep_avg_train.avg_values) - self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats) - if self.config.model_param_stats: - self.logger.model_weights(self.model, self.total_steps_done) - # scheduler step after the epoch - if self.scheduler is not None and self.config.scheduler_after_epoch: - if isinstance(self.scheduler, list): - for scheduler in self.scheduler: - if scheduler is not None: - scheduler.step() - else: - self.scheduler.step() - - ####################### - # EVAL FUNCTIONS - ####################### - - @staticmethod - def _model_eval_step( - batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None - ) -> Tuple[Dict, Dict]: - """ - Perform a evaluation forward pass. Compute model outputs and losses with no gradients. - - Args: - batch (Dict): IBatch of inputs. - model (nn.Module): Model to call evaluation. - criterion (nn.Module): Model criterion. - optimizer_idx (int, optional): Optimizer ID to define the closure in multi-optimizer training. Defaults to None. - - Returns: - Tuple[Dict, Dict]: model outputs and losses. - """ - input_args = [batch, criterion] - if optimizer_idx is not None: - input_args.append(optimizer_idx) - if hasattr(model, "module"): - return model.module.eval_step(*input_args) - return model.eval_step(*input_args) - - def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]: - """Perform a evaluation step on a batch of inputs and log the process. - - Args: - batch (Dict): Input batch. - step (int): Current step number in this epoch. - - Returns: - Tuple[Dict, Dict]: Model outputs and losses. 
- """ - with torch.no_grad(): - outputs = [] - loss_dict = {} - if not isinstance(self.optimizer, list): - outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion) - else: - outputs = [None] * len(self.optimizer) - for idx, _ in enumerate(self.optimizer): - criterion = self.criterion - outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx) - outputs[idx] = outputs_ - - if loss_dict_new is not None: - loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss") - loss_dict.update(loss_dict_new) - - loss_dict = self._detach_loss_dict(loss_dict) - - # update avg stats - update_eval_values = {} - for key, value in loss_dict.items(): - update_eval_values["avg_" + key] = value - self.keep_avg_eval.update_values(update_eval_values) - - if self.config.print_eval: - self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values) - return outputs, loss_dict - - def eval_epoch(self) -> None: - """Main entry point for the evaluation loop. Run evaluation on the all validation samples.""" - self.eval_loader = ( - self.get_eval_dataloader( - self.training_assets, - self.eval_samples, - verbose=True, - ) - if self.config.run_eval - else None - ) - - self.model.eval() - self.c_logger.print_eval_start() - loader_start_time = time.time() - batch = None - for cur_step, batch in enumerate(self.eval_loader): - # format data - batch = self.format_batch(batch) - loader_time = time.time() - loader_start_time - self.keep_avg_eval.update_values({"avg_loader_time": loader_time}) - outputs, _ = self.eval_step(batch, cur_step) - loader_start_time = time.time() - # plot epoch stats, artifacts and figures - if self.args.rank == 0: - if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"): - self.model.module.eval_log( - batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done - ) - elif hasattr(self.model, "eval_log"): - self.model.eval_log(batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done) - self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values) - - def test_run(self) -> None: - """Run test and log the results. Test run must be defined by the model. 
- Model must return figures and audios to be logged by the Tensorboard.""" - if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")): - if self.eval_loader is None: - self.eval_loader = self.get_eval_dataloader( - self.training_assets, - self.eval_samples, - verbose=True, - ) - - if hasattr(self.eval_loader.dataset, "load_test_samples"): - samples = self.eval_loader.dataset.load_test_samples(1) - if self.num_gpus > 1: - figures, audios = self.model.module.test_run(self.training_assets, samples, None) - else: - figures, audios = self.model.test_run(self.training_assets, samples, None) - else: - if self.num_gpus > 1: - figures, audios = self.model.module.test_run(self.training_assets) - else: - figures, audios = self.model.test_run(self.training_assets) - self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) - self.dashboard_logger.test_figures(self.total_steps_done, figures) - - def _restore_best_loss(self): - """Restore the best loss from the args.best_path if provided else - from the model (`args.restore_path` or `args.continue_path`) used for resuming the training""" - if self.restore_step != 0 or self.args.best_path: - print(f" > Restoring best loss from {os.path.basename(self.args.best_path)} ...") - ch = load_fsspec(self.args.restore_path, map_location="cpu") - if "model_loss" in ch: - self.best_loss = ch["model_loss"] - print(f" > Starting with loaded last best loss {self.best_loss}.") - - ################################### - # FIT FUNCTIONS - ################################### - - def _fit(self) -> None: - """🏃 train -> evaluate -> test for the number of epochs.""" - self._restore_best_loss() - - self.total_steps_done = self.restore_step - - for epoch in range(0, self.config.epochs): - if self.num_gpus > 1: - # let all processes sync up before starting with a new epoch of training - dist.barrier() - self.callbacks.on_epoch_start(self) - self.keep_avg_train = KeepAverage() - self.keep_avg_eval = KeepAverage() if self.config.run_eval else None - self.epochs_done = epoch - self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path) - if not self.args.skip_train_epoch: - self.train_epoch() - if self.config.run_eval: - self.eval_epoch() - if epoch >= self.config.test_delay_epochs and self.args.rank <= 0: - self.test_run() - self.c_logger.print_epoch_end( - epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values - ) - if self.args.rank in [None, 0]: - self.save_best_model() - self.callbacks.on_epoch_end(self) - - def fit(self) -> None: - """Where the ✨️magic✨️ happens...""" - try: - self._fit() - if self.args.rank == 0: - self.dashboard_logger.finish() - except KeyboardInterrupt: - self.callbacks.on_keyboard_interrupt(self) - # if the output folder is empty remove the run. - remove_experiment_folder(self.output_path) - # clear the DDP processes - if self.num_gpus > 1: - dist.destroy_process_group() - # finish the wandb run and sync data - if self.args.rank == 0: - self.dashboard_logger.finish() - # stop without error signal - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except BaseException: # pylint: disable=broad-except - remove_experiment_folder(self.output_path) - traceback.print_exc() - sys.exit(1) - - def save_best_model(self) -> None: - """Save the best model. 
It only saves if the current target loss is smaller then the previous.""" - - # set the target loss to choose the best model - target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train) - - # save the model and update the best_loss - self.best_loss = save_best_model( - target_loss_dict, - self.best_loss, - self.config, - self.model, - self.optimizer, - self.scaler if self.use_amp_scaler else None, - self.total_steps_done, - self.epochs_done, - self.output_path, - keep_all_best=self.config.keep_all_best, - keep_after=self.config.keep_after, - ) - - ##################### - # GET FUNCTIONS - ##################### - - @staticmethod - def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]: - """Receive the optimizer from the model if model implements `get_optimizer()` else - check the optimizer parameters in the config and try initiating the optimizer. - - Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. - - Returns: - Union[torch.optim.Optimizer, List]: A optimizer or a list of optimizers. GAN models define a list. - """ - if hasattr(model, "get_optimizer"): - optimizer = model.get_optimizer() - if optimizer is None: - optimizer_name = config.optimizer - optimizer_params = config.optimizer_params - return get_optimizer(optimizer_name, optimizer_params, config.lr, model) - return optimizer - - @staticmethod - def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]: - """Set the initial learning rate by the model if model implements `get_lr()` else try setting the learning rate - fromthe config. - - Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. - - Returns: - Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimzier. - """ - lr = None - if hasattr(model, "get_lr"): - lr = model.get_lr() - if lr is None: - lr = config.lr - return lr - - @staticmethod - def get_scheduler( - model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List] - ) -> Union[torch.optim.lr_scheduler._LRScheduler, List]: # pylint: disable=protected-access - """Receive the scheduler from the model if model implements `get_scheduler()` else - check the config and try initiating the scheduler. - - Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. - - Returns: - Union[torch.optim.Optimizer, List]: A scheduler or a list of schedulers, one for each optimizer. - """ - scheduler = None - if hasattr(model, "get_scheduler"): - scheduler = model.get_scheduler(optimizer) - if scheduler is None: - lr_scheduler = config.lr_scheduler - lr_scheduler_params = config.lr_scheduler_params - return get_scheduler(lr_scheduler, lr_scheduler_params, optimizer) - return scheduler - - @staticmethod - def get_criterion(model: nn.Module) -> nn.Module: - """Receive the criterion from the model. Model must implement `get_criterion()`. - - Args: - model (nn.Module): Training model. - - Returns: - nn.Module: Criterion layer. - """ - criterion = None - criterion = model.get_criterion() - return criterion - - #################### - # HELPER FUNCTIONS - #################### - - @staticmethod - def _detach_loss_dict(loss_dict: Dict) -> Dict: - """Detach loss values from autograp. - - Args: - loss_dict (Dict): losses. - - Returns: - Dict: losses detached from autograph. 
- """ - loss_dict_detached = {} - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_detached[key] = value - else: - loss_dict_detached[key] = value.detach().item() - return loss_dict_detached - - def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict: - """Pick the target loss to compare models""" - target_avg_loss = None - - # return if target loss defined in the model config - if "target_loss" in self.config and self.config.target_loss: - return keep_avg_target[f"avg_{self.config.target_loss}"] - - # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers - if isinstance(self.optimizer, list): - target_avg_loss = 0 - for idx in range(len(self.optimizer)): - target_avg_loss += keep_avg_target[f"avg_loss_{idx}"] - target_avg_loss /= len(self.optimizer) - else: - target_avg_loss = keep_avg_target["avg_loss"] - return target_avg_loss - - def _setup_logger_config(self, log_file: str) -> None: - """Write log strings to a file and print logs to the terminal. - TODO: Causes formatting issues in pdb debugging.""" - - class Logger(object): - def __init__(self, print_to_terminal=True): - self.print_to_terminal = print_to_terminal - self.terminal = sys.stdout - self.log_file = log_file - - def write(self, message): - if self.print_to_terminal: - self.terminal.write(message) - with open(self.log_file, "a", encoding="utf-8") as f: - f.write(message) - - def flush(self): - # this flush method is needed for python 3 compatibility. - # this handles the flush command by doing nothing. - # you might want to specify some extra behavior here. - pass - - # don't let processes rank > 0 write to the terminal - sys.stdout = Logger(self.args.rank == 0) - - @staticmethod - def _is_apex_available() -> bool: - """Check if Nvidia's APEX is available.""" - return importlib.util.find_spec("apex") is not None diff --git a/TTS/utils/logging/__init__.py b/TTS/utils/logging/__init__.py deleted file mode 100644 index 43fbf6f1..00000000 --- a/TTS/utils/logging/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from TTS.utils.logging.console_logger import ConsoleLogger -from TTS.utils.logging.tensorboard_logger import TensorboardLogger -from TTS.utils.logging.wandb_logger import WandbLogger - - -def init_dashboard_logger(config): - if config.dashboard_logger == "tensorboard": - dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model) - - elif config.dashboard_logger == "wandb": - project_name = config.model - if config.project_name: - project_name = config.project_name - - dashboard_logger = WandbLogger( - project=project_name, - name=config.run_name, - config=config, - entity=config.wandb_entity, - ) - - dashboard_logger.add_text("model-config", f"
<pre>{config.to_json()}</pre>
", 0) - - return dashboard_logger diff --git a/TTS/utils/logging/console_logger.py b/TTS/utils/logging/console_logger.py deleted file mode 100644 index 74371342..00000000 --- a/TTS/utils/logging/console_logger.py +++ /dev/null @@ -1,105 +0,0 @@ -import datetime - -from TTS.utils.io import AttrDict - -tcolors = AttrDict( - { - "OKBLUE": "\033[94m", - "HEADER": "\033[95m", - "OKGREEN": "\033[92m", - "WARNING": "\033[93m", - "FAIL": "\033[91m", - "ENDC": "\033[0m", - "BOLD": "\033[1m", - "UNDERLINE": "\033[4m", - } -) - - -class ConsoleLogger: - def __init__(self): - # TODO: color code for value changes - # use these to compare values between iterations - self.old_train_loss_dict = None - self.old_epoch_loss_dict = None - self.old_eval_loss_dict = None - - # pylint: disable=no-self-use - def get_time(self): - now = datetime.datetime.now() - return now.strftime("%Y-%m-%d %H:%M:%S") - - def print_epoch_start(self, epoch, max_epoch, output_path=None): - print( - "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), - flush=True, - ) - if output_path is not None: - print(f" --> {output_path}") - - def print_train_start(self): - print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") - - def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_dict): - indent = " | > " - print() - log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format( - tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC - ) - for key, value in loss_dict.items(): - if f"avg_{key}" in avg_loss_dict.keys(): - # print the avg value if given - if isinstance(value, float) and round(value, 5) == 0: - # do not round the number if it is zero when rounded - log_text += "{}{}: {} ({})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) - else: - # print the rounded value - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) - else: - if isinstance(value, float) and round(value, 5) == 0: - log_text += "{}{}: {} \n".format(indent, key, value) - else: - log_text += "{}{}: {:.5f} \n".format(indent, key, value) - print(log_text, flush=True) - - # pylint: disable=unused-argument - def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict): - indent = " | > " - log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n" - for key, value in print_dict.items(): - log_text += "{}{}: {:.5f}\n".format(indent, key, value) - print(log_text, flush=True) - - def print_eval_start(self): - print(f"\n{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n") - - def print_eval_step(self, step, loss_dict, avg_loss_dict): - indent = " | > " - log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n" - for key, value in loss_dict.items(): - # print the avg value if given - if f"avg_{key}" in avg_loss_dict.keys(): - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) - else: - log_text += "{}{}: {:.5f} \n".format(indent, key, value) - print(log_text, flush=True) - - def print_epoch_end(self, epoch, avg_loss_dict): - indent = " | > " - log_text = "\n {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC) - for key, value in avg_loss_dict.items(): - # print the avg value if given - color = "" - sign = "+" - diff = 0 - if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict: - diff = value - self.old_eval_loss_dict[key] - if diff < 0: - color = tcolors.OKGREEN - sign = "" 
- elif diff > 0: - color = tcolors.FAIL - sign = "+" - log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff) - self.old_eval_loss_dict = avg_loss_dict - print(log_text, flush=True) diff --git a/TTS/utils/logging/tensorboard_logger.py b/TTS/utils/logging/tensorboard_logger.py deleted file mode 100644 index 812683f7..00000000 --- a/TTS/utils/logging/tensorboard_logger.py +++ /dev/null @@ -1,79 +0,0 @@ -import traceback - -from tensorboardX import SummaryWriter - - -class TensorboardLogger(object): - def __init__(self, log_dir, model_name): - self.model_name = model_name - self.writer = SummaryWriter(log_dir) - - def model_weights(self, model, step): - layer_num = 1 - for name, param in model.named_parameters(): - if param.numel() == 1: - self.writer.add_scalar("layer{}-{}/value".format(layer_num, name), param.max(), step) - else: - self.writer.add_scalar("layer{}-{}/max".format(layer_num, name), param.max(), step) - self.writer.add_scalar("layer{}-{}/min".format(layer_num, name), param.min(), step) - self.writer.add_scalar("layer{}-{}/mean".format(layer_num, name), param.mean(), step) - self.writer.add_scalar("layer{}-{}/std".format(layer_num, name), param.std(), step) - self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) - self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) - layer_num += 1 - - def dict_to_tb_scalar(self, scope_name, stats, step): - for key, value in stats.items(): - self.writer.add_scalar("{}/{}".format(scope_name, key), value, step) - - def dict_to_tb_figure(self, scope_name, figures, step): - for key, value in figures.items(): - self.writer.add_figure("{}/{}".format(scope_name, key), value, step) - - def dict_to_tb_audios(self, scope_name, audios, step, sample_rate): - for key, value in audios.items(): - if value.dtype == "float16": - value = value.astype("float32") - try: - self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate) - except RuntimeError: - traceback.print_exc() - - def train_step_stats(self, step, stats): - self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step) - - def train_epoch_stats(self, step, stats): - self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step) - - def train_figures(self, step, figures): - self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step) - - def train_audios(self, step, audios, sample_rate): - self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate) - - def eval_stats(self, step, stats): - self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step) - - def eval_figures(self, step, figures): - self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step) - - def eval_audios(self, step, audios, sample_rate): - self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate) - - def test_audios(self, step, audios, sample_rate): - self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate) - - def test_figures(self, step, figures): - self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step) - - def add_text(self, title, text, step): - self.writer.add_text(title, text, step) - - def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613, R0201 - yield - - def flush(self): - self.writer.flush() - - def finish(self): - self.writer.close() diff --git a/TTS/utils/logging/wandb_logger.py 
b/TTS/utils/logging/wandb_logger.py deleted file mode 100644 index 5fcab00f..00000000 --- a/TTS/utils/logging/wandb_logger.py +++ /dev/null @@ -1,111 +0,0 @@ -# pylint: disable=W0613 - -import traceback -from pathlib import Path - -try: - import wandb - from wandb import finish, init # pylint: disable=W0611 -except ImportError: - wandb = None - - -class WandbLogger: - def __init__(self, **kwargs): - - if not wandb: - raise Exception("install wandb using `pip install wandb` to use WandbLogger") - - self.run = None - self.run = wandb.init(**kwargs) if not wandb.run else wandb.run - self.model_name = self.run.config.model - self.log_dict = {} - - def model_weights(self, model): - layer_num = 1 - for name, param in model.named_parameters(): - if param.numel() == 1: - self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()}) - else: - self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()}) - self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()}) - self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()}) - self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()}) - self.log_dict["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param) - self.log_dict["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad) - layer_num += 1 - - def dict_to_scalar(self, scope_name, stats): - for key, value in stats.items(): - self.log_dict["{}/{}".format(scope_name, key)] = value - - def dict_to_figure(self, scope_name, figures): - for key, value in figures.items(): - self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value) - - def dict_to_audios(self, scope_name, audios, sample_rate): - for key, value in audios.items(): - if value.dtype == "float16": - value = value.astype("float32") - try: - self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate) - except RuntimeError: - traceback.print_exc() - - def log(self, log_dict, prefix="", flush=False): - for key, value in log_dict.items(): - self.log_dict[prefix + key] = value - if flush: # for cases where you don't want to accumulate data - self.flush() - - def train_step_stats(self, step, stats): - self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats) - - def train_epoch_stats(self, step, stats): - self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats) - - def train_figures(self, step, figures): - self.dict_to_figure(f"{self.model_name}_TrainFigures", figures) - - def train_audios(self, step, audios, sample_rate): - self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate) - - def eval_stats(self, step, stats): - self.dict_to_scalar(f"{self.model_name}_EvalStats", stats) - - def eval_figures(self, step, figures): - self.dict_to_figure(f"{self.model_name}_EvalFigures", figures) - - def eval_audios(self, step, audios, sample_rate): - self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate) - - def test_audios(self, step, audios, sample_rate): - self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate) - - def test_figures(self, step, figures): - self.dict_to_figure(f"{self.model_name}_TestFigures", figures) - - def add_text(self, title, text, step): - pass - - def flush(self): - if self.run: - wandb.log(self.log_dict) - self.log_dict = {} - - def finish(self): - if self.run: - self.run.finish() - - def log_artifact(self, file_or_dir, name, artifact_type, 
aliases=None): - if not self.run: - return - name = "_".join([self.run.id, name]) - artifact = wandb.Artifact(name, type=artifact_type) - data_path = Path(file_or_dir) - if data_path.is_dir(): - artifact.add_dir(str(data_path)) - elif data_path.is_file(): - artifact.add_file(str(data_path)) - - self.run.log_artifact(artifact, aliases=aliases) From 590b04fb89223d6350c98a50a6459d9b70c12799 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:36:32 +0100 Subject: [PATCH 152/214] Fix espeak_wrapper --- .../utils/text/phonemizers/espeak_wrapper.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index 3cccee41..2fe0c39c 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -11,7 +11,7 @@ def is_tool(name): return which(name) is not None - +# priority: espeakng > espeak if is_tool("espeak-ng"): _DEF_ESPEAK_LIB = "espeak-ng" elif is_tool("espeak"): @@ -21,6 +21,7 @@ else: def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: + """Run espeak with the given arguments.""" cmd = [ espeak_lib, "-q", @@ -85,7 +86,8 @@ class ESpeak(BasePhonemizer): def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True): if self._ESPEAK_LIB is None: - raise Exception("Unknown backend: %s" % backend) + raise Exception(" [!] No espeak backend found. Install espeak-ng or espeak on your system.") + self.backend = self._ESPEAK_LIB # band-aid for backwards compatibility if language == "en": @@ -104,6 +106,16 @@ class ESpeak(BasePhonemizer): if backend not in ["espeak", "espeak-ng"]: raise Exception("Unknown backend: %s" % backend) self._ESPEAK_LIB = backend + # skip first two characters of the returned text + # "_ p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + # ^^ + self.num_skip_chars = 2 + if backend == "espeak-ng": + # skip the first character of the returned text + # "_p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + # ^ + self.num_skip_chars = 1 + def auto_set_espeak_lib(self) -> None: if is_tool("espeak-ng"): @@ -151,7 +163,7 @@ class ESpeak(BasePhonemizer): phonemes = "" for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): logging.debug("line: %s", repr(line)) - phonemes += line.decode("utf8").strip()[2:] # skip two redundant characters + phonemes += line.decode("utf8").strip()[self.num_skip_chars:] # skip initial redundant characters return phonemes.replace("_", separator) def _phonemize(self, text, separator=None): From 54c6bb2a8cce664ce2492ecb946c61f830e78965 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:37:13 +0100 Subject: [PATCH 153/214] Fix add speaker VITS --- TTS/tts/models/vits.py | 53 ++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index a69b02ba..f02090cf 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -799,21 +799,7 @@ class Vits(BaseTTS): o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) return o_hat, y_mask, (z, z_p, z_hat) - def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: - """Perform a single training step. Run the model forward pass and compute losses. - - Args: - batch (Dict): Input tensors. - criterion (nn.Module): Loss layer designed for the model.
- optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. - - Returns: - Tuple[Dict, Dict]: Model ouputs and computed losses. - """ - # pylint: disable=attribute-defined-outside-init - if optimizer_idx not in [0, 1]: - raise ValueError(" [!] Unexpected `optimizer_idx`.") - + def _freeze_layers(self): if self.args.freeze_encoder: for param in self.text_encoder.parameters(): param.requires_grad = False @@ -838,6 +824,24 @@ class Vits(BaseTTS): for param in self.waveform_decoder.parameters(): param.requires_grad = False + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Perform a single training step. Run the model forward pass and compute losses. + + Args: + batch (Dict): Input tensors. + criterion (nn.Module): Loss layer designed for the model. + optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. + + Returns: + Tuple[Dict, Dict]: Model outputs and computed losses. + """ + + # pylint: disable=attribute-defined-outside-init + if optimizer_idx not in [0, 1]: + raise ValueError(" [!] Unexpected `optimizer_idx`.") + + self._freeze_layers() + if optimizer_idx == 0: text_input = batch["text_input"] text_lengths = batch["text_lengths"] @@ -848,6 +852,9 @@ class Vits(BaseTTS): language_ids = batch["language_ids"] waveform = batch["waveform"] + # if (waveform > 1).sum() > 0 or (waveform < -1).sum() > 0: + # breakpoint() + # generator pass outputs = self.forward( text_input, @@ -859,8 +866,6 @@ class Vits(BaseTTS): ) # cache tensors for the discriminator - self.y_disc_cache = None - self.wav_seg_disc_cache = None self.y_disc_cache = outputs["model_outputs"] self.wav_seg_disc_cache = outputs["waveform_seg"] @@ -888,6 +893,9 @@ class Vits(BaseTTS): syn_spk_emb=outputs["syn_spk_emb"], ) + # if loss_dict["loss_feat"].isnan().sum() > 0 or loss_dict["loss_feat"].isinf().sum() > 0: + # breakpoint() + elif optimizer_idx == 1: # discriminator pass outputs = {} @@ -984,7 +992,11 @@ class Vits(BaseTTS): test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) except: # pylint: disable=bare-except print(" !! Error creating Test Sentence -", idx) - return test_figures, test_audios + return {"figures": test_figures, "audios": test_audios} + + def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps:int) -> None: + logger.test_audios(steps, outputs['audios'], self.ap.sample_rate) + logger.test_figures(steps, outputs['figures']) def get_optimizer(self) -> List: """Initiate and return the GAN optimizers based on the config parameters.
@@ -1056,9 +1068,10 @@ class Vits(BaseTTS): state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} # handle fine-tuning from a checkpoint with additional speakers if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: - print(" > Loading checkpoint with additional speakers.") + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") emb_g = state["model"]["emb_g.weight"] - new_row = torch.zeros(1, emb_g.shape[1]) + new_row = torch.randn(num_new_speakers, emb_g.shape[1]) emb_g = torch.cat([emb_g, new_row], axis=0) state["model"]["emb_g.weight"] = emb_g From d3a58ed07a1820f3f591e68a4a4f71d36ea2924d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:37:30 +0100 Subject: [PATCH 154/214] Fix default values --- TTS/tts/configs/vits_config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index d306552d..a8c7f91d 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -17,7 +17,7 @@ class VitsConfig(BaseTTSConfig): Model architecture arguments. Defaults to `VitsArgs()`. grad_clip (List): - Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`. + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. lr_gen (float): Initial learning rate for the generator. Defaults to 0.0002. @@ -114,7 +114,6 @@ class VitsConfig(BaseTTSConfig): feat_loss_alpha: float = 1.0 mel_loss_alpha: float = 45.0 dur_loss_alpha: float = 1.0 - aligner_loss_alpha = 1.0 speaker_encoder_loss_alpha: float = 1.0 # data loader params From aa8145472162d2956325d2d31c2457eb162c6756 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:37:57 +0100 Subject: [PATCH 155/214] Update BaseTrainingConfig --- TTS/config/shared_configs.py | 113 ++--------------------------------- 1 file changed, 4 insertions(+), 109 deletions(-) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index 392f10af..6394b264 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass from typing import List from coqpit import Coqpit, check_argument +from trainer import TrainerConfig @dataclass @@ -237,130 +238,24 @@ class BaseDatasetConfig(Coqpit): @dataclass -class BaseTrainingConfig(Coqpit): - """Base config to define the basic training parameters that are shared - among all the models. +class BaseTrainingConfig(TrainerConfig): + """Base config to define the basic 🐸TTS training parameters that are shared + among all the models. It is based on ```Trainer.TrainingConfig```. Args: model (str): Name of the model that is used in the training. - run_name (str): - Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`. - - run_description (str): - Short description of the experiment. - - epochs (int): - Number training epochs. Defaults to 10000. - - batch_size (int): - Training batch size. - - eval_batch_size (int): - Validation batch size. - - mixed_precision (bool): - Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however - it may also cause numerical unstability in some cases. - - scheduler_after_epoch (bool): - If true, run the scheduler step after each epoch else run it after each model step. 
- - run_eval (bool): - Enable / Disable evaluation (validation) run. Defaults to True. - - test_delay_epochs (int): - Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful - results, hence waiting for a couple of epochs might save some time. - - print_eval (bool): - Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at - the end of the evaluation. Default to ```False```. - - print_step (int): - Number of steps required to print the next training log. - - log_dashboard (str): "tensorboard" or "wandb" - Set the experiment tracking tool - - plot_step (int): - Number of steps required to log training on Tensorboard. - - model_param_stats (bool): - Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. - Defaults to ```False```. - - project_name (str): - Name of the project. Defaults to config.model - - wandb_entity (str): - Name of W&B entity/team. Enables collaboration across a team or org. - - log_model_step (int): - Number of steps required to log a checkpoint as W&B artifact - - save_step (int): - Number of steps required to save the next checkpoint. - - checkpoint (bool): - Enable / Disable checkpointing. - - keep_all_best (bool): - Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults - to ```False```. - - keep_after (int): - Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults - to 10000. - num_loader_workers (int): Number of workers for training time dataloader. num_eval_loader_workers (int): Number of workers for evaluation time dataloader. - - output_path (str): - Path for training output folder, either a local file path or other - URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or - S3 (s3://) paths. The nonexist part of the given path is created - automatically. All training artefacts are saved there. 
""" model: str = None - run_name: str = "coqui_tts" - run_description: str = "" - # training params - epochs: int = 10000 - batch_size: int = None - eval_batch_size: int = None - mixed_precision: bool = False - scheduler_after_epoch: bool = False - # eval params - run_eval: bool = True - test_delay_epochs: int = 0 - print_eval: bool = False - # logging - dashboard_logger: str = "tensorboard" - print_step: int = 25 - plot_step: int = 100 - model_param_stats: bool = False - project_name: str = None - log_model_step: int = None - wandb_entity: str = None - # checkpointing - save_step: int = 10000 - checkpoint: bool = True - keep_all_best: bool = False - keep_after: int = 10000 # dataloading num_loader_workers: int = 0 num_eval_loader_workers: int = 0 use_noise_augment: bool = False use_language_weighted_sampler: bool = False - - # paths - output_path: str = None - # distributed - distributed_backend: str = "nccl" - distributed_url: str = "tcp://localhost:54321" From 27db089d6c27fb5a58513abd73c5451fac981a21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:42:12 +0100 Subject: [PATCH 156/214] Change TrainingArgs -> TrainerArgs --- TTS/bin/train_tts.py | 4 ++-- TTS/bin/train_vocoder.py | 4 ++-- recipes/ljspeech/align_tts/train_aligntts.py | 4 ++-- recipes/ljspeech/fast_pitch/train_fast_pitch.py | 4 ++-- recipes/ljspeech/fast_speech/train_fast_speech.py | 4 ++-- recipes/ljspeech/glow_tts/train_glowtts.py | 4 ++-- recipes/ljspeech/hifigan/train_hifigan.py | 4 ++-- .../ljspeech/multiband_melgan/train_multiband_melgan.py | 4 ++-- recipes/ljspeech/speedy_speech/train_speedy_speech.py | 4 ++-- recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py | 4 ++-- recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py | 4 ++-- recipes/ljspeech/univnet/train.py | 4 ++-- recipes/ljspeech/vits_tts/train_vits.py | 9 ++++----- recipes/ljspeech/wavegrad/train_wavegrad.py | 4 ++-- recipes/ljspeech/wavernn/train_wavernn.py | 4 ++-- recipes/multilingual/vits_tts/train_vits_tts.py | 4 ++-- recipes/vctk/fast_pitch/train_fast_pitch.py | 4 ++-- recipes/vctk/fast_speech/train_fast_speech.py | 4 ++-- recipes/vctk/glow_tts/train_glow_tts.py | 4 ++-- recipes/vctk/speedy_speech/train_speedy_speech.py | 4 ++-- recipes/vctk/tacotron-DDC/train_tacotron-DDC.py | 4 ++-- recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py | 4 ++-- recipes/vctk/tacotron2/train_tacotron2.py | 4 ++-- recipes/vctk/vits/train_vits.py | 4 ++-- 24 files changed, 50 insertions(+), 51 deletions(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 824f0128..79b78767 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,7 +1,7 @@ import os from TTS.config import load_config, register_config -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model @@ -9,7 +9,7 @@ from TTS.tts.models import setup_model def main(): """Run `tts` model training directly by a `config.json` file.""" # init trainer args - train_args = TrainingArgs() + train_args = TrainerArgs() parser = train_args.init_argparse(arg_prefix="") # override trainer args from comman-line args diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index cd665f29..081fdd56 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,7 +1,7 @@ import os from TTS.config import load_config, register_config -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.utils.audio 
import AudioProcessor from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model @@ -10,7 +10,7 @@ from TTS.vocoder.models import setup_model def main(): """Run `tts` model training directly by a `config.json` file.""" # init trainer args - train_args = TrainingArgs() + train_args = TrainerArgs() parser = train_args.init_argparse(arg_prefix="") # override trainer args from comman-line args diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py index d0187aa8..a4b868aa 100644 --- a/recipes/ljspeech/align_tts/train_aligntts.py +++ b/recipes/ljspeech/align_tts/train_aligntts.py @@ -1,6 +1,6 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.align_tts_config import AlignTTSConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples @@ -57,7 +57,7 @@ model = AlignTTS(config, ap, tokenizer) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀 diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 3a772251..fcb62282 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS @@ -90,6 +90,6 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) trainer.fit() diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py index f9f1bc06..183c8ebb 100644 --- a/recipes/ljspeech/fast_speech/train_fast_speech.py +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -1,7 +1,7 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS @@ -89,6 +89,6 @@ model = ForwardTTS(config, ap, tokenizer) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) trainer.fit() diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index dd450a57..c47cd00a 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -2,7 +2,7 @@ 
import os # Trainer: Where the ✨️ happens. # TrainingArgs: Defines the set of arguments of the Trainer. -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs # GlowTTSConfig: all model related values for training, validating and testing. from TTS.tts.configs.glow_tts_config import GlowTTSConfig @@ -72,7 +72,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀 diff --git a/recipes/ljspeech/hifigan/train_hifigan.py b/recipes/ljspeech/hifigan/train_hifigan.py index 8d1c272a..964a6420 100644 --- a/recipes/ljspeech/hifigan/train_hifigan.py +++ b/recipes/ljspeech/hifigan/train_hifigan.py @@ -1,6 +1,6 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import HifiganConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -40,7 +40,7 @@ model = GAN(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py index 90c52997..6f528a83 100644 --- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py +++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py @@ -1,6 +1,6 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import MultibandMelganConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -40,7 +40,7 @@ model = GAN(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 2f8896c5..6a9ddf16 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -1,7 +1,7 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS @@ -75,7 +75,7 @@ model = ForwardTTS(config, ap, tokenizer) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 
🚀 diff --git a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py index a7f037e6..c3a1c51c 100644 --- a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py +++ b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples @@ -88,7 +88,7 @@ model = Tacotron2(config, ap, tokenizer) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀 diff --git a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py index 285c416c..a7482b32 100644 --- a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py +++ b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples @@ -83,7 +83,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py index 589fd027..35240c5b 100644 --- a/recipes/ljspeech/univnet/train.py +++ b/recipes/ljspeech/univnet/train.py @@ -1,6 +1,6 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import UnivnetConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -39,7 +39,7 @@ model = GAN(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index 79c0db2e..24ff4d0f 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples @@ -33,7 +33,7 @@ audio_config = BaseAudioConfig( config = VitsConfig( audio=audio_config, run_name="vits_ljspeech", - batch_size=16, + batch_size=32, eval_batch_size=16, batch_group_size=5, num_loader_workers=0, @@ -48,8 +48,7 @@ config = VitsConfig( compute_input_seq_cache=True, print_step=25, print_eval=True, - mixed_precision=False, - max_seq_len=500000, + mixed_precision=True, output_path=output_path, datasets=[dataset_config], ) @@ -76,7 +75,7 @@ model = Vits(config, ap, tokenizer, speaker_manager=None) # init the trainer 
and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py index 6786c052..095773d6 100644 --- a/recipes/ljspeech/wavegrad/train_wavegrad.py +++ b/recipes/ljspeech/wavegrad/train_wavegrad.py @@ -1,6 +1,6 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import WavegradConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -37,7 +37,7 @@ model = Wavegrad(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py index f64f5752..172b489a 100644 --- a/recipes/ljspeech/wavernn/train_wavernn.py +++ b/recipes/ljspeech/wavernn/train_wavernn.py @@ -1,6 +1,6 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import WavernnConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -39,7 +39,7 @@ model = Wavernn(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py index be4747df..391f31cb 100644 --- a/recipes/multilingual/vits_tts/train_vits_tts.py +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -2,7 +2,7 @@ import os from glob import glob from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples @@ -119,7 +119,7 @@ model = Vits(config, speaker_manager, language_manager) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py index 4d9cc10d..aeb62055 100644 --- a/recipes/vctk/fast_pitch/train_fast_pitch.py +++ b/recipes/vctk/fast_pitch/train_fast_pitch.py @@ -1,7 +1,7 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS @@ -85,7 +85,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 
🚀 diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py index 1dcab982..578fbd1a 100644 --- a/recipes/vctk/fast_speech/train_fast_speech.py +++ b/recipes/vctk/fast_speech/train_fast_speech.py @@ -1,7 +1,7 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS @@ -83,7 +83,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀 diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py index e35e552d..0f198a86 100644 --- a/recipes/vctk/glow_tts/train_glow_tts.py +++ b/recipes/vctk/glow_tts/train_glow_tts.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples @@ -83,7 +83,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=speaker_manager) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀 diff --git a/recipes/vctk/speedy_speech/train_speedy_speech.py b/recipes/vctk/speedy_speech/train_speedy_speech.py index 85e347fc..fbb1af2d 100644 --- a/recipes/vctk/speedy_speech/train_speedy_speech.py +++ b/recipes/vctk/speedy_speech/train_speedy_speech.py @@ -1,7 +1,7 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS @@ -83,7 +83,7 @@ model = ForwardTTS(config, ap, tokenizer, speaker_manager) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 
🚀 diff --git a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py index 7960b34b..917c5588 100644 --- a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py +++ b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron_config import TacotronConfig from TTS.tts.datasets import load_tts_samples @@ -85,7 +85,7 @@ model = Tacotron(config, ap, tokenizer, speaker_manager) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀 diff --git a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py index bc7951b5..759ddd57 100644 --- a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py +++ b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples @@ -91,7 +91,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀 diff --git a/recipes/vctk/tacotron2/train_tacotron2.py b/recipes/vctk/tacotron2/train_tacotron2.py index 82dedade..0c62da48 100644 --- a/recipes/vctk/tacotron2/train_tacotron2.py +++ b/recipes/vctk/tacotron2/train_tacotron2.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples @@ -91,7 +91,7 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 
🚀 diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py index caf1caa1..53d7242c 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -1,7 +1,7 @@ import os from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples @@ -90,7 +90,7 @@ model = Vits(config, ap, tokenizer, speaker_manager) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, From 33b98e6cc3b1358b6c6a116149def3142d7807d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:48:16 +0100 Subject: [PATCH 157/214] Update requirements.txt --- requirements.txt | 48 ++++++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/requirements.txt b/requirements.txt index ddb6def9..54a6bdfd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,30 +1,38 @@ -cython -flask -gdown -inflect -jieba -librosa==0.8.0 -matplotlib +# core deps numpy==1.19.5 -pandas -pypinyin -pysbd -pyyaml +cython scipy>=0.19.0 -soundfile -tensorboardX torch>=1.7 -tqdm +torchaudio +soundfile +librosa==0.8.0 numba==0.53 -umap-learn==0.5.1 +inflect +tqdm anyascii -coqpit +pyyaml +fsspec>=2021.04.0 +# deps for examples +flask +# deps for inference +pysbd +# deps for notebooks +umap-learn==0.5.1 +pandas +# deps for training +matplotlib +tensorboardX +pyworld +# coqui stack +git+https://github.com/coqui-ai/Trainer@main # trainer +coqpit # config management +# chinese g2p deps +jieba +pypinyin # japanese g2p deps mecab-python3==1.0.3 unidic-lite==1.0.8 # gruut+supported langs gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0 -fsspec>=2021.04.0 -pyworld -webrtcvad -torchaudio +# others +webrtcvad # for VAD From 8622226f3f26a8656767eb538854940f95dd8c12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:51:16 +0100 Subject: [PATCH 158/214] Make style --- .gitignore | 3 ++- TTS/bin/train_tts.py | 3 ++- TTS/bin/train_vocoder.py | 3 ++- TTS/tts/models/base_tts.py | 4 ++-- TTS/tts/models/vits.py | 6 +++--- TTS/tts/utils/text/phonemizers/espeak_wrapper.py | 4 ++-- recipes/ljspeech/align_tts/train_aligntts.py | 1 + recipes/ljspeech/fast_pitch/train_fast_pitch.py | 3 ++- recipes/ljspeech/fast_speech/train_fast_speech.py | 3 ++- recipes/ljspeech/hifigan/train_hifigan.py | 1 + recipes/ljspeech/multiband_melgan/train_multiband_melgan.py | 1 + recipes/ljspeech/speedy_speech/train_speedy_speech.py | 3 ++- recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py | 3 ++- recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py | 3 ++- recipes/ljspeech/univnet/train.py | 1 + recipes/ljspeech/vits_tts/train_vits.py | 3 ++- recipes/ljspeech/wavegrad/train_wavegrad.py | 1 + recipes/ljspeech/wavernn/train_wavernn.py | 1 + recipes/multilingual/vits_tts/train_vits_tts.py | 3 ++- recipes/vctk/fast_pitch/train_fast_pitch.py | 3 ++- recipes/vctk/fast_speech/train_fast_speech.py | 3 ++- recipes/vctk/glow_tts/train_glow_tts.py | 3 ++- recipes/vctk/speedy_speech/train_speedy_speech.py | 3 ++- recipes/vctk/tacotron-DDC/train_tacotron-DDC.py | 3 ++- recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py | 3 ++- recipes/vctk/tacotron2/train_tacotron2.py | 3 ++- recipes/vctk/vits/train_vits.py | 5 +++-- 27 files changed, 50 insertions(+), 26 deletions(-)
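Note: PATCH 156 and PATCH 157 together replace the bundled TTS.trainer module with the external coqui-ai/Trainer package pulled in through requirements.txt. A rough sketch of the call pattern every recipe now follows; `config`, `model`, `output_path` and the sample lists are assumed to be prepared as the recipe scripts above show:

from trainer import Trainer, TrainerArgs

# TrainerArgs replaces the old TrainingArgs; the rest of the Trainer API is unchanged in these patches
trainer = Trainer(
    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()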
diff --git a/.gitignore b/.gitignore index 7e9da0d8..f8d6e644 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,5 @@ internal/* *_pitch.npy *_phoneme.npy wandb -depot/* \ No newline at end of file +depot/* +coqui_recipes/* \ No newline at end of file diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 79b78767..73063731 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,7 +1,8 @@ import os -from TTS.config import load_config, register_config from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 081fdd56..6d4df610 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,7 +1,8 @@ import os -from TTS.config import load_config, register_config from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config from TTS.utils.audio import AudioProcessor from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 0eb2b5f3..ca3c4a28 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -324,9 +324,9 @@ class BaseTTS(BaseModel): loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, - shuffle=False, # shuffle is done in the dataset. + shuffle=False, # shuffle is done in the dataset. collate_fn=dataset.collate_fn, - drop_last=True, # setting this False might cause issues in AMP training. + drop_last=True, # setting this False might cause issues in AMP training. sampler=sampler, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index f02090cf..256ea3af 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -994,9 +994,9 @@ class Vits(BaseTTS): print(" !! Error creating Test Sentence -", idx) return {"figures": test_figures, "audios": test_audios} - def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps:int) -> None: - logger.test_audios(steps, outputs['audios'], self.ap.sample_rate) - logger.test_figures(steps, outputs['figures']) + def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) def get_optimizer(self) -> List: """Initiate and return the GAN optimizers based on the config parameters. 
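Note: the espeak_wrapper.py hunks below only re-flow the prefix stripping introduced in PATCH 152. As a self-contained sketch of what that stripping does, with sample strings taken from the comments in PATCH 152 (the helper name is invented for illustration; the real logic lives in ESpeak.__init__ and the phonemization loop shown there):

def strip_phoneme_prefix(line: str, backend: str) -> str:
    # espeak prepends two redundant leading characters ("_ "), espeak-ng only one ("_")
    num_skip_chars = 1 if backend == "espeak-ng" else 2
    return line.strip()[num_skip_chars:]

assert strip_phoneme_prefix("_ p_ɹ_ˈaɪ_ɚ t_ə\n", "espeak") == "p_ɹ_ˈaɪ_ɚ t_ə"
assert strip_phoneme_prefix("_p_ɹ_ˈaɪ_ɚ t_ə\n", "espeak-ng") == "p_ɹ_ˈaɪ_ɚ t_ə"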
diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index 2fe0c39c..442dcef2 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -11,6 +11,7 @@ def is_tool(name): return which(name) is not None + # priority: espeakng > espeak if is_tool("espeak-ng"): _DEF_ESPEAK_LIB = "espeak-ng" @@ -116,7 +117,6 @@ class ESpeak(BasePhonemizer): # ^ self.num_skip_chars = 1 - def auto_set_espeak_lib(self) -> None: if is_tool("espeak-ng"): self._ESPEAK_LIB = "espeak-ng" @@ -163,7 +163,7 @@ class ESpeak(BasePhonemizer): phonemes = "" for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): logging.debug("line: %s", repr(line)) - phonemes += line.decode("utf8").strip()[self.num_skip_chars:] # skip initial redundant characters + phonemes += line.decode("utf8").strip()[self.num_skip_chars :] # skip initial redundant characters return phonemes.replace("_", separator) def _phonemize(self, text, separator=None): diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py index a4b868aa..f1b29025 100644 --- a/recipes/ljspeech/align_tts/train_aligntts.py +++ b/recipes/ljspeech/align_tts/train_aligntts.py @@ -1,6 +1,7 @@ import os from trainer import Trainer, TrainerArgs + from TTS.tts.configs.align_tts_config import AlignTTSConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index fcb62282..a3fc35c9 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -1,7 +1,8 @@ import os -from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py index 183c8ebb..560d3de2 100644 --- a/recipes/ljspeech/fast_speech/train_fast_speech.py +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -1,7 +1,8 @@ import os -from TTS.config import BaseAudioConfig, BaseDatasetConfig from trainer import Trainer, TrainerArgs + +from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS diff --git a/recipes/ljspeech/hifigan/train_hifigan.py b/recipes/ljspeech/hifigan/train_hifigan.py index 964a6420..1e5bbf30 100644 --- a/recipes/ljspeech/hifigan/train_hifigan.py +++ b/recipes/ljspeech/hifigan/train_hifigan.py @@ -1,6 +1,7 @@ import os from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import HifiganConfig from TTS.vocoder.datasets.preprocess import load_wav_data diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py index 6f528a83..40ff5a00 100644 --- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py +++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py @@ -1,6 +1,7 @@ import os from trainer import Trainer, TrainerArgs + from 
TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import MultibandMelganConfig from TTS.vocoder.datasets.preprocess import load_wav_data diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 6a9ddf16..7ad132b2 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -1,7 +1,8 @@ import os -from TTS.config import BaseAudioConfig, BaseDatasetConfig from trainer import Trainer, TrainerArgs + +from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS diff --git a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py index c3a1c51c..ea1b0874 100644 --- a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py +++ b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py @@ -1,7 +1,8 @@ import os -from TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples diff --git a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py index a7482b32..d00f8ed7 100644 --- a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py +++ b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py @@ -1,7 +1,8 @@ import os -from TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py index 35240c5b..19c91925 100644 --- a/recipes/ljspeech/univnet/train.py +++ b/recipes/ljspeech/univnet/train.py @@ -1,6 +1,7 @@ import os from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import UnivnetConfig from TTS.vocoder.datasets.preprocess import load_wav_data diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index 24ff4d0f..cfb3351d 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -1,7 +1,8 @@ import os -from TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py index 095773d6..1abdf45d 100644 --- a/recipes/ljspeech/wavegrad/train_wavegrad.py +++ b/recipes/ljspeech/wavegrad/train_wavegrad.py @@ -1,6 +1,7 @@ import os from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import WavegradConfig from TTS.vocoder.datasets.preprocess import load_wav_data diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py index 172b489a..640f5092 
100644 --- a/recipes/ljspeech/wavernn/train_wavernn.py +++ b/recipes/ljspeech/wavernn/train_wavernn.py @@ -1,6 +1,7 @@ import os from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import WavernnConfig from TTS.vocoder.datasets.preprocess import load_wav_data diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py index 391f31cb..ea4f377b 100644 --- a/recipes/multilingual/vits_tts/train_vits_tts.py +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -1,8 +1,9 @@ import os from glob import glob -from TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py index aeb62055..986202c5 100644 --- a/recipes/vctk/fast_pitch/train_fast_pitch.py +++ b/recipes/vctk/fast_pitch/train_fast_pitch.py @@ -1,7 +1,8 @@ import os -from TTS.config import BaseAudioConfig, BaseDatasetConfig from trainer import Trainer, TrainerArgs + +from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py index 578fbd1a..fe785a41 100644 --- a/recipes/vctk/fast_speech/train_fast_speech.py +++ b/recipes/vctk/fast_speech/train_fast_speech.py @@ -1,7 +1,8 @@ import os -from TTS.config import BaseAudioConfig, BaseDatasetConfig from trainer import Trainer, TrainerArgs + +from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py index 0f198a86..ebdbfb37 100644 --- a/recipes/vctk/glow_tts/train_glow_tts.py +++ b/recipes/vctk/glow_tts/train_glow_tts.py @@ -1,7 +1,8 @@ import os -from TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples diff --git a/recipes/vctk/speedy_speech/train_speedy_speech.py b/recipes/vctk/speedy_speech/train_speedy_speech.py index fbb1af2d..80d21ca2 100644 --- a/recipes/vctk/speedy_speech/train_speedy_speech.py +++ b/recipes/vctk/speedy_speech/train_speedy_speech.py @@ -1,7 +1,8 @@ import os -from TTS.config import BaseAudioConfig, BaseDatasetConfig from trainer import Trainer, TrainerArgs + +from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS diff --git a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py index 917c5588..bed21ad9 100644 --- a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py +++ b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py @@ -1,7 +1,8 @@ import os -from 
TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron_config import TacotronConfig from TTS.tts.datasets import load_tts_samples diff --git a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py index 759ddd57..caa745b3 100644 --- a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py +++ b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py @@ -1,7 +1,8 @@ import os -from TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples diff --git a/recipes/vctk/tacotron2/train_tacotron2.py b/recipes/vctk/tacotron2/train_tacotron2.py index 0c62da48..43f5d4e6 100644 --- a/recipes/vctk/tacotron2/train_tacotron2.py +++ b/recipes/vctk/tacotron2/train_tacotron2.py @@ -1,7 +1,8 @@ import os -from TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py index 53d7242c..dff4eefc 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -1,7 +1,8 @@ import os -from TTS.config.shared_configs import BaseAudioConfig from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples @@ -57,7 +58,7 @@ config = VitsConfig( print_step=25, print_eval=False, mixed_precision=True, - max_text_len= 325, # change this if you have a larger VRAM than 16GB + max_text_len=325, # change this if you have a larger VRAM than 16GB output_path=output_path, datasets=[dataset_config], ) From ab8a4ca2c3dfac43374922777153fe42f04002c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 5 Feb 2022 20:29:40 +0100 Subject: [PATCH 159/214] Revert random segment --- TTS/tts/models/vits.py | 3 +-- TTS/tts/utils/helpers.py | 51 +++++++++++++--------------------------- 2 files changed, 17 insertions(+), 37 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 256ea3af..751187ea 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -656,14 +656,13 @@ class Vits(BaseTTS): logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) # select a random feature segment for the waveform decoder - z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) + z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size) o = self.waveform_decoder(z_slice, g=g) wav_seg = segment( waveform, slice_ids * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length, - pad_short=True, ) if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index c2e7f561..9ccb5d62 100644 --- 
a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -57,7 +57,7 @@ def sequence_mask(sequence_length, max_len=None): return mask -def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4): """Segment each sample in a batch based on the provided segment indices Args: @@ -66,25 +66,16 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_ segment_size (int): Expected output segment size. pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. """ - # pad the input tensor if it is shorter than the segment size - if pad_short and x.shape[-1] < segment_size: - x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) - - segments = torch.zeros_like(x[:, :, :segment_size]) - + ret = torch.zeros_like(x[:, :, :segment_size]) for i in range(x.size(0)): - index_start = segment_indices[i] - index_end = index_start + segment_size - x_i = x[i] - if pad_short and index_end > x.size(2): - # pad the sample if it is shorter than the segment size - x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) - segments[i] = x_i[:, index_start:index_end] - return segments + idx_str = segment_indices[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret def rand_segments( - x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4 ): """Create random segments based on the input lengths. @@ -99,25 +90,15 @@ def rand_segments( - x: :math:`[B, C, T]` - x_lengths: :math:`[B]` """ - _x_lenghts = x_lengths.clone() - B, _, T = x.size() - if pad_short: - if T < segment_size: - x = torch.nn.functional.pad(x, (0, segment_size - T)) - T = segment_size - if _x_lenghts is None: - _x_lenghts = T - len_diff = _x_lenghts - segment_size + 1 - if let_short_samples: - _x_lenghts[len_diff < 0] = segment_size - len_diff = _x_lenghts - segment_size + 1 - else: - assert all( - len_diff > 0 - ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" - segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() - ret = segment(x, segment_indices, segment_size) - return ret, segment_indices + b, _, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + if (ids_str_max < 0).sum(): + raise ValueError("Segment size is larger than the input length.") + ids_str = (torch.rand([b]).to(x.device) * ids_str_max).long() + ret = segment(x, ids_str, segment_size) + return ret, ids_str def average_over_durations(values, durs): From 4b96bfe92568fbf5e25d9b0bfa595537ee319c92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 5 Feb 2022 20:30:55 +0100 Subject: [PATCH 160/214] Fix train logging --- TTS/tts/models/vits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 751187ea..7dac1bb9 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -947,8 +947,8 @@ class Vits(BaseTTS): Tuple[Dict, np.ndarray]: training plots and output waveform. 
""" figures, audios = self._log(self.ap, batch, outputs, "train") - logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, self.ap.sample_rate) + logger.train_figures(steps, figures) + logger.train_figures(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): From 1a43e05460196b8dd167df12c3a9bcb499b97a1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 5 Feb 2022 20:33:36 +0100 Subject: [PATCH 161/214] Fix VITS loss bug Fake and real features were given in the wrong args order to the loss function --- TTS/tts/layers/losses.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 827da751..0c94f91f 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -553,7 +553,6 @@ class VitsGeneratorLoss(nn.Module): rl = rl.float().detach() gl = gl.float() loss += torch.mean(torch.abs(rl - gl)) - return loss * 2 @staticmethod @@ -629,9 +628,16 @@ class VitsGeneratorLoss(nn.Module): mel_hat = self.stft(waveform_hat) # compute losses - loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha - loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha - loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha + loss_kl = self.kl_loss( + z_p=z_p, + logs_q=logs_q, + m_p=m_p, + logs_p=logs_p, + z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha + loss_feat = self.feature_loss( + feats_real=feats_disc_real, + feats_generated=feats_disc_fake) * self.feat_loss_alpha + loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration @@ -675,7 +681,7 @@ class VitsDiscriminatorLoss(nn.Module): def forward(self, scores_disc_real, scores_disc_fake): loss = 0.0 return_dict = {} - loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake) + loss_disc, _, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake) return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha loss = loss + return_dict["loss_disc"] return_dict["loss"] = loss From 7dfd753d91689888e660b1ca596649faca530090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 5 Feb 2022 20:34:17 +0100 Subject: [PATCH 162/214] Add a cheap trick to avoid short audio clips --- TTS/tts/datasets/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 499e6b7b..af726818 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -229,7 +229,8 @@ class TTSDataset(Dataset): # after phonemization the text length may change # this is a shareful 🤭 hack to prevent longer phonemes # TODO: find a better fix - if len(token_ids) > self.max_text_len: + if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len: + self.rescue_item_idx += 1 return self.load_data(self.rescue_item_idx) # get f0 values From 1e219fef0a74e5a6e38aa795e09db3f6ac10513e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 5 Feb 2022 20:34:39 +0100 Subject: [PATCH 163/214] Revert drop_last --- TTS/tts/models/base_tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index ca3c4a28..becf22f9 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -326,7 +326,7 @@ class BaseTTS(BaseModel): batch_size=config.eval_batch_size if is_eval else config.batch_size, shuffle=False, # shuffle is done in the dataset. collate_fn=dataset.collate_fn, - drop_last=True, # setting this False might cause issues in AMP training. + drop_last=False, # setting this False might cause issues in AMP training. sampler=sampler, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, From b0cff949f5a4a7907a03c46fba993e4b7643bce3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:31:45 +0100 Subject: [PATCH 164/214] Update tests --- tests/text_tests/test_characters.py | 7 +------ tests/text_tests/test_tokenizer.py | 1 + tests/tts_tests/test_glow_tts.py | 2 +- tests/vocoder_tests/test_multiband_melgan_train.py | 1 + 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py index 3f4086d5..5432c652 100644 --- a/tests/text_tests/test_characters.py +++ b/tests/text_tests/test_characters.py @@ -1,15 +1,10 @@ import unittest -from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, create_graphemes, create_phonemes +from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes # pylint: disable=protected-access -def test_make_symbols(): - _ = create_phonemes() - _ = create_graphemes() - - class BaseCharacterTest(unittest.TestCase): def setUp(self): self.characters_empty = BaseCharacters("", "", pad="", eos="", bos="", blank="", is_unique=True, is_sorted=True) diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py index 47174518..908952ea 100644 --- a/tests/text_tests/test_tokenizer.py +++ b/tests/text_tests/test_tokenizer.py @@ -64,6 +64,7 @@ class TestTTSTokenizer(unittest.TestCase): def test_init_from_config(self): @dataclass class Characters(Coqpit): + characters_class: str = None characters: str = _phonemes punctuations: str = _punctuations pad: str = _pad diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index 305f86b8..85b5ed7a 100644 --- a/tests/tts_tests/test_glow_tts.py +++ b/tests/tts_tests/test_glow_tts.py @@ -11,7 +11,7 @@ from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.models.glow_tts import GlowTTS from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor -from TTS.utils.logging.tensorboard_logger import TensorboardLogger +from trainer.logging.tensorboard_logger import TensorboardLogger # pylint: disable=unused-variable diff --git a/tests/vocoder_tests/test_multiband_melgan_train.py b/tests/vocoder_tests/test_multiband_melgan_train.py index c49107bd..80027607 100644 --- a/tests/vocoder_tests/test_multiband_melgan_train.py +++ b/tests/vocoder_tests/test_multiband_melgan_train.py @@ -20,6 +20,7 @@ config = MultibandMelganConfig( eval_split_size=1, print_step=1, print_eval=True, + steps_to_start_discriminator=1, data_path="tests/data/ljspeech", discriminator_model_params={"base_channels": 16, "max_channels": 64, "downsample_factors": [4, 4, 4]}, output_path=output_path, From c9117298960d83e1b1e04d23eca4cda1b1c8bc20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:32:28 +0100 Subject: [PATCH 165/214] Update BaseTrainerModel --- TTS/model.py | 93 
++++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 54 insertions(+), 39 deletions(-) diff --git a/TTS/model.py b/TTS/model.py index 6ce11e63..d7bd4f9f 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -1,39 +1,28 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple -import numpy as np import torch from coqpit import Coqpit from torch import nn -# pylint: skip-file -class BaseModel(nn.Module, ABC): +class BaseTrainerModel(ABC, nn.Module): """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this. - - Notes on input/output tensor shapes: - Any input or output tensor of the model must be shaped as - - - 3D tensors `batch x time x channels` - - 2D tensors `batch x channels` - - 1D tensors `batch x 1` """ - def __init__(self, config: Coqpit): - super().__init__() - @staticmethod + @abstractmethod def init_from_config(config: Coqpit): """Init the model from given config. Override this depending on your model. """ - pass + ... @abstractmethod def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: - """Forward pass for the model mainly used in training. + """Forward ... for the model mainly used in training. You can be flexible here and use different number of arguments and argument names since it is intended to be used by `train_step()` without exposing it out of the model. @@ -51,7 +40,7 @@ class BaseModel(nn.Module, ABC): @abstractmethod def inference(self, input: torch.Tensor, aux_input={}) -> Dict: - """Forward pass for inference. + """Forward ... for inference. We don't use `*kwargs` since it is problematic with the TorchScript API. @@ -66,9 +55,25 @@ class BaseModel(nn.Module, ABC): ... return outputs_dict + def format_batch(self, batch: Dict) -> Dict: + """Format batch returned by the data loader before sending it to the model. + + If not implemented, model uses the batch as is. + Can be used for data augmentation, feature extraction, etc. + """ + return batch + + def format_batch_on_device(self, batch: Dict) -> Dict: + """Format batch on device before sending it to the model. + + If not implemented, model uses the batch as is. + Can be used for data augmentation, feature extraction, etc. + """ + return batch + @abstractmethod def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - """Perform a single training step. Run the model forward pass and compute losses. + """Perform a single training step. Run the model forward ... and compute losses. Args: batch (Dict): Input tensors. @@ -96,11 +101,11 @@ class BaseModel(nn.Module, ABC): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - pass + ... @abstractmethod def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can + """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can call `train_step()` with no changes. Args: @@ -117,45 +122,55 @@ def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: """The same as `train_log()`""" - pass + ... + @abstractmethod - def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: + def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None: """Load a checkpoint and get ready for training or inference. 
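Implementations are expected to restore weights with `load_state_dict()` (honoring `strict`) and, when `eval` is True, switch the module to inference mode via `eval()`.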
Args: config (Coqpit): Model configuration. checkpoint_path (str): Path to the model checkpoint file. eval (bool, optional): If true, init model for inference else for training. Defaults to False. + strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. """ ... @staticmethod @abstractmethod - def init_from_config(config: Coqpit): + def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel": """Init the model from given config. Override this depending on your model. """ - pass + ... - def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: - """Setup an return optimizer or optimizers.""" - pass + @abstractmethod + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: True, + data_items: List, + verbose: bool, + num_gpus: int): + ... - def get_lr(self) -> Union[float, List[float]]: - """Return learning rate(s). + # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: + # """Set up and return optimizer or optimizers.""" + # ... - Returns: - Union[float, List[float]]: Model's initial learning rates. - """ - pass + # def get_lr(self) -> Union[float, List[float]]: + # """Return learning rate(s). - def get_scheduler(self, optimizer: torch.optim.Optimizer): - pass + # Returns: + # Union[float, List[float]]: Model's initial learning rates. + # """ + # ... - def get_criterion(self): - pass + # def get_scheduler(self, optimizer: torch.optim.Optimizer): + # ... - def format_batch(self): - pass + # def get_criterion(self): + # ... From be3a03126ad4b2630c6006019dc28e3bcea6c994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:35:42 +0100 Subject: [PATCH 166/214] Update imports for trainer --- TTS/bin/distribute.py | 4 ++-- TTS/bin/train_encoder.py | 4 +++- TTS/bin/train_tts.py | 8 +++++++- TTS/bin/train_vocoder.py | 8 +++++++- TTS/server/server.py | 4 ++-- TTS/speaker_encoder/utils/training.py | 19 +++++++++++++------ 6 files changed, 34 insertions(+), 13 deletions(-) diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py index 06d5f388..40f60d5d 100644 --- a/TTS/bin/distribute.py +++ b/TTS/bin/distribute.py @@ -8,14 +8,14 @@ import time import torch -from TTS.trainer import TrainingArgs +from trainer import TrainerArgs def main(): """ Call train.py as a new process and pass command arguments """ - parser = TrainingArgs().init_argparse(arg_prefix="") + parser = TrainerArgs().init_argparse(arg_prefix="") parser.add_argument("--script", type=str, help="Target training script to distibute.") args, unargs = parser.parse_known_args() diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 8c364300..f19966ee 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -9,6 +9,8 @@ import traceback import torch from torch.utils.data import DataLoader +from trainer.torch import NoamLR + from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model @@ -19,7 +21,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.io import load_fsspec from TTS.utils.radam import RAdam -from TTS.utils.training import NoamLR, check_update +from TTS.utils.training import check_update torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True diff --git 
a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 73063731..467685b2 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os from trainer import Trainer, TrainerArgs @@ -7,10 +8,15 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model +@dataclass +class TrainTTSArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + def main(): """Run `tts` model training directly by a `config.json` file.""" # init trainer args - train_args = TrainerArgs() + train_args = TrainTTSArgs() parser = train_args.init_argparse(arg_prefix="") # override trainer args from comman-line args diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 6d4df610..c52fd962 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os from trainer import Trainer, TrainerArgs @@ -8,10 +9,15 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model +@dataclass +class TrainVocoderArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + def main(): """Run `tts` model training directly by a `config.json` file.""" # init trainer args - train_args = TrainerArgs() + train_args = TrainVocoderArgs() parser = train_args.init_argparse(arg_prefix="") # override trainer args from comman-line args diff --git a/TTS/server/server.py b/TTS/server/server.py index f2512582..aef507fd 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path: if args.vocoder_name is not None and not args.vocoder_path: vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) -# CASE3: set custome model paths +# CASE3: set custom model paths if args.model_path is not None: model_path = args.model_path config_path = args.config_path @@ -170,9 +170,9 @@ def tts(): text = request.args.get("text") speaker_idx = request.args.get("speaker_id", "") style_wav = request.args.get("style_wav", "") - style_wav = style_wav_uri_to_dict(style_wav) print(" > Model input: {}".format(text)) + print(" > Speaker Idx: {}".format(speaker_idx)) wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav) out = io.BytesIO() synthesizer.save_wav(wavs, out) diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index b202ebcd..5c2de274 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -1,19 +1,26 @@ +from asyncio.log import logger +from dataclasses import dataclass, field import os from coqpit import Coqpit from TTS.config import load_config, register_config -from TTS.trainer import TrainingArgs +from trainer import TrainerArgs from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from TTS.utils.logging import init_dashboard_logger -from TTS.utils.logging.console_logger import ConsoleLogger +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger from TTS.utils.trainer_utils import get_last_checkpoint +@dataclass +class TrainArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + def getarguments(): - 
train_config = TrainingArgs() + train_config = TrainArgs() parser = train_config.init_argparse(arg_prefix="") return parser @@ -75,13 +82,13 @@ def process_args(args, config=None): used_characters = parse_symbols() new_fields["characters"] = used_characters copy_model_files(config, experiment_path, new_fields) - dashboard_logger = init_dashboard_logger(config) + dashboard_logger = logger_factory(config, experiment_path) c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, dashboard_logger def init_arguments(): - train_config = TrainingArgs() + train_config = TrainArgs() parser = train_config.init_argparse(arg_prefix="") return parser From 20a677c6238384cd842c706bb0142f0751dcfc37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:36:27 +0100 Subject: [PATCH 167/214] Update test_run in wavernn and wavegrad --- TTS/vocoder/models/wavegrad.py | 7 ++++--- TTS/vocoder/models/wavernn.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 9d6e431c..58fc8762 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -270,12 +270,13 @@ class Wavegrad(BaseVocoder): ) -> None: pass - def test_run(self, assets: Dict, samples: List[Dict], outputs: Dict): # pylint: disable=unused-argument + def test(self, assets: Dict, test_loader:"DataLoader", outputs=None): # pylint: disable=unused-argument # setup noise schedule and inference ap = assets["audio_processor"] noise_schedule = self.config["test_noise_schedule"] betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) self.compute_noise_level(betas) + samples = test_loader.dataset.load_test_samples(1) for sample in samples: x = sample[0] x = x[None, :, :].to(next(self.parameters()).device) @@ -307,12 +308,12 @@ class Wavegrad(BaseVocoder): return {"input": m, "waveform": y} def get_data_loader( - self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int + self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int ): ap = assets["audio_processor"] dataset = WaveGradDataset( ap=ap, - items=data_items, + items=samples, seq_len=self.config.seq_len, hop_len=ap.hop_length, pad_short=self.config.pad_short, diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 68f9b2c8..6686db45 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -568,12 +568,13 @@ class Wavernn(BaseVocoder): return self.train_step(batch, criterion) @torch.no_grad() - def test_run( - self, assets: Dict, samples: List[Dict], output: Dict # pylint: disable=unused-argument + def test( + self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument ) -> Tuple[Dict, Dict]: ap = assets["audio_processor"] figures = {} audios = {} + samples = test_loader.dataset.load_test_samples(1) for idx, sample in enumerate(samples): x = torch.FloatTensor(sample[0]) x = x.to(next(self.parameters()).device) @@ -600,14 +601,14 @@ class Wavernn(BaseVocoder): config: Coqpit, assets: Dict, is_eval: True, - data_items: List, + samples: List, verbose: bool, num_gpus: int, ): ap = assets["audio_processor"] dataset = WaveRNNDataset( ap=ap, - items=data_items, + items=samples, seq_len=config.seq_len, hop_len=ap.hop_length, pad=config.model_args.pad, From fc3b6d2861640403559d765c4fa7e6e0d0bd5df5 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:37:34 +0100 Subject: [PATCH 168/214] Update gan --- TTS/vocoder/models/gan.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 6978f0e7..d4abaa0a 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -80,8 +80,8 @@ class GAN(BaseVocoder): Returns: Tuple[Dict, Dict]: model outputs and the computed loss values. """ - outputs = None - loss_dict = None + outputs = {} + loss_dict = {} x = batch["input"] y = batch["waveform"] @@ -311,7 +311,7 @@ class GAN(BaseVocoder): config: Coqpit, assets: Dict, is_eval: True, - data_items: List, + samples: List, verbose: bool, num_gpus: int, rank: int = None, # pylint: disable=unused-argument @@ -322,7 +322,7 @@ class GAN(BaseVocoder): config (Coqpit): Model config. ap (AudioProcessor): Audio processor. is_eval (True): Set the dataloader for evaluation if true. - data_items (List): Data samples. + samples (List): Data samples. verbose (bool): Log information if true. num_gpus (int): Number of GPUs in use. rank (int): Rank of the current GPU. Defaults to None. @@ -332,7 +332,7 @@ class GAN(BaseVocoder): """ dataset = GANDataset( ap=self.ap, - items=data_items, + items=samples, seq_len=config.seq_len, hop_len=self.ap.hop_length, pad_short=config.pad_short, From 833de62e3059fdc76bbafd5afeb2d34793ac8736 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:38:24 +0100 Subject: [PATCH 169/214] Update base_vocoder --- TTS/vocoder/models/base_vocoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py index 2728525c..01a7ff68 100644 --- a/TTS/vocoder/models/base_vocoder.py +++ b/TTS/vocoder/models/base_vocoder.py @@ -1,11 +1,11 @@ from coqpit import Coqpit -from TTS.model import BaseModel +from TTS.model import BaseTrainerModel # pylint: skip-file -class BaseVocoder(BaseModel): +class BaseVocoder(BaseTrainerModel): """Base `vocoder` class. Every new `vocoder` model must inherit this. It defines `vocoder` specific functions on top of `Model`. @@ -19,7 +19,7 @@ class BaseVocoder(BaseModel): """ def __init__(self, config): - super().__init__(config) + super().__init__() self._set_model_args(config) def _set_model_args(self, config: Coqpit): From 2bad09862572f33e1c7ab531899a64fbca594de3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:47:42 +0100 Subject: [PATCH 170/214] Implement BaseVocabulary --- TTS/tts/utils/text/characters.py | 109 +++++++++++++++++++++++++++---- TTS/utils/training.py | 14 ---- 2 files changed, 97 insertions(+), 26 deletions(-) diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index aae6844f..f6c04370 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,4 +1,6 @@ +from abc import ABC from dataclasses import replace +from typing import Dict from TTS.tts.configs.shared_configs import CharactersConfig @@ -79,6 +81,71 @@ _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprase # DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank) +class BaseVocabulary: + """Base Vocabulary class. + + This class only needs a vocabulary dictionary without specifying the characters. + + Args: + vocab (Dict): A dictionary of characters and their corresponding indices. 
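+ pad (str): Special padding character. Defaults to None. + blank (str): Special blank character. Defaults to None. + bos (str): Special beginning-of-sequence character. Defaults to None. + eos (str): Special end-of-sequence character. Defaults to None. + + Example (illustrative sketch only; assumes ids follow the dictionary's insertion order, as `enumerate` over the dict implies): + >>> vocab = BaseVocabulary({"a": 0, "b": 1, "c": 2}) + >>> vocab.char_to_id("b") + 1 + >>> vocab.id_to_char(2) + 'c'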
+ """ + + def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None): + self.vocab = vocab + self.pad = pad + self.blank = blank + self.bos = bos + self.eos = eos + + @property + def pad_id(self) -> int: + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def vocab(self): + return self._vocab + + @vocab.setter + def vocab(self, vocab): + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension + } + + @staticmethod + def init_from_config(config, **kwargs): + if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict: + return ( + BaseVocabulary( + config.characters.vocab_dict, + config.characters.pad, + config.characters.blank, + config.characters.bos, + config.characters.eos, + ), + config, + ) + return BaseVocabulary(**kwargs), config + + @property + def num_chars(self): + return max(self._vocab.values()) + 1 + + def char_to_id(self, char: str) -> int: + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e + + def id_to_char(self, idx: int) -> str: + return self._id_to_char[idx] + + class BaseCharacters: """🐸BaseCharacters class @@ -116,12 +183,12 @@ class BaseCharacters: def __init__( self, - characters: str, - punctuations: str, - pad: str, - eos: str, - bos: str, - blank: str, + characters: str = None, + punctuations: str = None, + pad: str = None, + eos: str = None, + bos: str = None, + blank: str = None, is_unique: bool = False, is_sorted: bool = True, ) -> None: @@ -135,6 +202,14 @@ class BaseCharacters: self.is_sorted = is_sorted self._create_vocab() + @property + def pad_id(self) -> int: + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + @property def characters(self): return self._characters @@ -193,6 +268,14 @@ class BaseCharacters: def vocab(self): return self._vocab + @vocab.setter + def vocab(self, vocab): + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension + } + @property def num_chars(self): return len(self._vocab) @@ -208,11 +291,7 @@ class BaseCharacters: _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab - self._vocab = _vocab + list(self._punctuations) - self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} - self._id_to_char = { - idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension - } + self.vocab = _vocab + list(self._punctuations) if self.is_unique: duplicates = {x for x in self.vocab if self.vocab.count(x) > 1} assert ( @@ -248,7 +327,13 @@ class BaseCharacters: Implement this method for your subclass. """ - ... 
+ # use character set from config + if config.characters is not None: + return BaseCharacters(**config.characters), config + # return default character set + characters = BaseCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config def to_config(self) -> "CharactersConfig": return CharactersConfig( diff --git a/TTS/utils/training.py b/TTS/utils/training.py index 9f01b310..e69fb2b4 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -30,20 +30,6 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): return grad_norm, skip_flag -# pylint: disable=protected-access -class NoamLR(torch.optim.lr_scheduler._LRScheduler): - def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): - self.warmup_steps = float(warmup_steps) - super().__init__(optimizer, last_epoch) - - def get_lr(self): - step = max(self.last_epoch, 1) - return [ - base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5) - for base_lr in self.base_lrs - ] - - def gradual_training_scheduler(global_step, config): """Setup the gradual training schedule wrt number of active GPUs""" From 35fc7270ff7a0a047562427411200c15f63d27e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:50:13 +0100 Subject: [PATCH 171/214] Implement BaseTTS --- TTS/tts/models/base_tts.py | 19 +++++++------------ TTS/tts/utils/text/tokenizer.py | 4 ++++ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index becf22f9..6dd7ca72 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -9,7 +9,7 @@ from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from TTS.model import BaseModel +from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler @@ -19,27 +19,22 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file -class BaseTTS(BaseModel): +class BaseTTS(BaseTrainerModel): """Base `tts` class. Every new `tts` model must inherit this. It defines common `tts` specific functions on top of `Model` implementation. 
- - Notes on input/output tensor shapes: - Any input or output tensor of the model must be shaped as - - - 3D tensors `batch x time x channels` - - 2D tensors `batch x channels` - - 1D tensors `batch x 1` """ def __init__( - self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None + self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None ): - super().__init__(config) + super().__init__() self.config = config self.ap = ap self.tokenizer = tokenizer self.speaker_manager = speaker_manager + self.language_manager = language_manager self._set_model_args(config) def _set_model_args(self, config: Coqpit): @@ -262,7 +257,7 @@ class BaseTTS(BaseModel): d_vector_mapping = None # setup multi-lingual attributes - if hasattr(self, "language_manager"): + if hasattr(self, "language_manager") and self.language_manager is not None: language_id_mapping = ( self.language_manager.language_id_mapping if self.args.use_language_embedding else None ) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index bdaf8ea6..50a5f519 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -119,6 +119,10 @@ class TTSTokenizer: return [self.characters.bos] + list(char_sequence) + [self.characters.eos] def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False): + """Intersperses the blank character between characters in a sequence. + + Use the ```blank``` character if defined else use the ```pad``` character. + """ char_to_use = self.characters.blank if use_blank_char else self.characters.pad result = [char_to_use] * (len(char_sequence) * 2 + 1) result[1::2] = char_sequence From d0c27a9661f5b1b19d207391e13edc5ea742dda1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:50:42 +0100 Subject: [PATCH 172/214] Update synthesis.py --- TTS/tts/utils/synthesis.py | 7 ++++--- TTS/utils/training.py | 28 ---------------------------- 2 files changed, 4 insertions(+), 31 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index e2d9c113..377f32de 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -193,14 +193,15 @@ def synthesis( # convert outputs to numpy # plot results wav = None - if hasattr(model, "END2END") and model.END2END: - wav = model_outputs.squeeze(0) - else: + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] if use_griffin_lim: wav = inv_spectrogram(model_outputs, model.ap, CONFIG) # trim silence if do_trim_silence: wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs return_dict = { "wav": wav, "alignments": alignments, diff --git a/TTS/utils/training.py b/TTS/utils/training.py index e69fb2b4..b51f55e9 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -42,31 +42,3 @@ def gradual_training_scheduler(global_step, config): if global_step * num_gpus >= values[0]: new_values = values return new_values[1], new_values[2] - - -def lr_decay(init_lr, global_step, warmup_steps): - r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py - It is only being used by the Speaker Encoder trainer.""" - warmup_steps = float(warmup_steps) - step = global_step + 1.0 - lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, step**-0.5) - return lr - - -# pylint: disable=dangerous-default-value -def set_weight_decay(model, 
weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}): - """ - Skip biases, BatchNorm parameters, rnns. - and attention projection layer v - """ - decay = [] - no_decay = [] - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - - if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)): - no_decay.append(param) - else: - decay.append(param) - return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}] From 935a60404629e63f96c4df77042b0e3908f2da48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:51:03 +0100 Subject: [PATCH 173/214] Delete trainer_utils --- TTS/utils/trainer_utils.py | 150 ------------------------------------- 1 file changed, 150 deletions(-) delete mode 100644 TTS/utils/trainer_utils.py diff --git a/TTS/utils/trainer_utils.py b/TTS/utils/trainer_utils.py deleted file mode 100644 index dabb33cd..00000000 --- a/TTS/utils/trainer_utils.py +++ /dev/null @@ -1,150 +0,0 @@ -import importlib -import os -import re -from typing import Dict, List, Tuple -from urllib.parse import urlparse - -import fsspec -import torch - -from TTS.utils.io import load_fsspec -from TTS.utils.training import NoamLR - - -def is_apex_available(): - return importlib.util.find_spec("apex") is not None - - -def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]: - """Setup PyTorch environment for training. - - Args: - cudnn_enable (bool): Enable/disable CUDNN. - cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better to set to False if input sequence length is - variable between batches. - use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise. - - Returns: - Tuple[bool, int]: is cuda on or off and number of GPUs in the environment. - """ - num_gpus = torch.cuda.device_count() - if num_gpus > 1 and not use_ddp: - raise RuntimeError( - f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`." - ) - torch.backends.cudnn.enabled = cudnn_enable - torch.backends.cudnn.benchmark = cudnn_benchmark - torch.manual_seed(54321) - use_cuda = torch.cuda.is_available() - print(" > Using CUDA: ", use_cuda) - print(" > Number of GPUs: ", num_gpus) - return use_cuda, num_gpus - - -def get_scheduler( - lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer -) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access - """Find, initialize and return a scheduler. - - Args: - lr_scheduler (str): Scheduler name. - lr_scheduler_params (Dict): Scheduler parameters. - optimizer (torch.optim.Optimizer): Optimizer to pass to the scheduler. - - Returns: - torch.optim.lr_scheduler._LRScheduler: Functional scheduler. - """ - if lr_scheduler is None: - return None - if lr_scheduler.lower() == "noamlr": - scheduler = NoamLR - else: - scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler) - return scheduler(optimizer, **lr_scheduler_params) - - -def get_optimizer( - optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None -) -> torch.optim.Optimizer: - """Find, initialize and return a optimizer. - - Args: - optimizer_name (str): Optimizer name. - optimizer_params (dict): Optimizer parameters. - lr (float): Initial learning rate. - model (torch.nn.Module): Model to pass to the optimizer. 
- - Returns: - torch.optim.Optimizer: Functional optimizer. - """ - if optimizer_name.lower() == "radam": - module = importlib.import_module("TTS.utils.radam") - optimizer = getattr(module, "RAdam") - else: - optimizer = getattr(torch.optim, optimizer_name) - if model is not None: - parameters = model.parameters() - return optimizer(parameters, lr=lr, **optimizer_params) - - -def get_last_checkpoint(path: str) -> Tuple[str, str]: - """Get latest checkpoint or/and best model in path. - - It is based on globbing for `*.pth.tar` and the RegEx - `(checkpoint|best_model)_([0-9]+)`. - - Args: - path: Path to files to be compared. - - Raises: - ValueError: If no checkpoint or best_model files are found. - - Returns: - Path to the last checkpoint - Path to best checkpoint - """ - fs = fsspec.get_mapper(path).fs - file_names = fs.glob(os.path.join(path, "*.pth.tar")) - scheme = urlparse(path).scheme - if scheme: # scheme is not preserved in fs.glob, add it back - file_names = [scheme + "://" + file_name for file_name in file_names] - last_models = {} - last_model_nums = {} - for key in ["checkpoint", "best_model"]: - last_model_num = None - last_model = None - # pass all the checkpoint files and find - # the one with the largest model number suffix. - for file_name in file_names: - match = re.search(f"{key}_([0-9]+)", file_name) - if match is not None: - model_num = int(match.groups()[0]) - if last_model_num is None or model_num > last_model_num: - last_model_num = model_num - last_model = file_name - - # if there is no checkpoint found above - # find the checkpoint with the latest - # modification date. - key_file_names = [fn for fn in file_names if key in fn] - if last_model is None and len(key_file_names) > 0: - last_model = max(key_file_names, key=os.path.getctime) - last_model_num = load_fsspec(last_model)["step"] - - if last_model is not None: - last_models[key] = last_model - last_model_nums[key] = last_model_num - - # check what models were found - if not last_models: - raise ValueError(f"No models found in continue path {path}!") - if "checkpoint" not in last_models: # no checkpoint just best model - last_models["checkpoint"] = last_models["best_model"] - elif "best_model" not in last_models: # no best model - # this shouldn't happen, but let's handle it just in case - last_models["best_model"] = last_models["checkpoint"] - # finally check if last best model is more recent than checkpoint - elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: - last_models["checkpoint"] = last_models["best_model"] - - return last_models["checkpoint"], last_models["best_model"] From 00c7600103ee34ac50506af88f1b34b713f849e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:51:37 +0100 Subject: [PATCH 174/214] Update Vits model API --- TTS/tts/models/vits.py | 831 ++++++++++++++++++++++++++++++----------- 1 file changed, 617 insertions(+), 214 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 7dac1bb9..b7766e92 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,24 +1,31 @@ +import collections import math +import os from dataclasses import dataclass, field, replace from itertools import chain -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist import torchaudio from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn from torch import nn from torch.cuda.amp.autocast_mode import autocast from 
torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from TTS.tts.configs.shared_configs import CharactersConfig +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer @@ -27,6 +34,263 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +############################## +# IO / Feature extraction +############################## + +hann_window = {} +mel_basis = {} + + +def load_audio(file_path): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load(file_path) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + 
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = amp_to_db(spec) + return spec + + +############################## +# DATASET +############################## + + +class VitsDataset(TTSDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pad_id = self.tokenizer.characters.pad_id + + def __getitem__(self, idx): + item = self.samples[idx] + + text, wav_file, speaker_name, language_name, _ = _parse_sample(item) + raw_text = text + + wav, sr = load_audio(wav_file) + wav_filename = os.path.basename(wav_file) + + token_ids = self.get_token_ids(idx, text) + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "wav_file": wav_filename, + "speaker_name": speaker_name, + "language_name": language_name, + } + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = 
torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + return { + "tokens": token_padded, + "token_lens": token_lens, + "token_rel_lens": token_rel_lens, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + } + + +############################## +# MODEL DEFINITION +############################## + @dataclass class VitsArgs(Coqpit): @@ -268,38 +532,20 @@ class Vits(BaseTTS): Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. Examples: - Init only model layers. - >>> from TTS.tts.configs.vits_config import VitsConfig >>> from TTS.tts.models.vits import Vits >>> config = VitsConfig() >>> model = Vits(config) - - Fully init a model ready for action. All the class attributes and class members - (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values. - - >>> from TTS.tts.configs.vits_config import VitsConfig - >>> from TTS.tts.models.vits import Vits - >>> config = VitsConfig() - >>> model = Vits.init_from_config(config) """ - # pylint: disable=dangerous-default-value - - def __init__( - self, + def __init__(self, config: Coqpit, ap: "AudioProcessor" = None, tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, - language_manager: LanguageManager = None, - ): + language_manager: LanguageManager = None,): - super().__init__(config, ap, tokenizer, speaker_manager) - - self.END2END = True - self.speaker_manager = speaker_manager - self.language_manager = language_manager + super().__init__(config, ap, tokenizer, speaker_manager, language_manager) self.init_multispeaker(config) self.init_multilingual(config) @@ -363,10 +609,6 @@ class Vits(BaseTTS): language_emb_dim=self.embedded_language_dim, ) - upsample_rate = math.prod(self.args.upsample_rates_decoder) - assert ( - upsample_rate == self.config.audio.hop_length - ), f" [!] 
Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}" self.waveform_decoder = HifiganGenerator( self.args.hidden_channels, 1, @@ -398,6 +640,7 @@ class Vits(BaseTTS): """ self.embedded_speaker_dim = 0 self.num_speakers = self.args.num_speakers + self.audio_transform = None if self.speaker_manager: self.num_speakers = self.speaker_manager.num_speakers @@ -428,8 +671,11 @@ class Vits(BaseTTS): orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], ) - else: - self.audio_transform = None + # pylint: disable=W0101,W0105 + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], + ) def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init @@ -463,6 +709,35 @@ class Vits(BaseTTS): self.embedded_language_dim = 0 self.emb_l = None + def get_aux_input(self, aux_input: Dict): + sid, g, lid = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _freeze_layers(self): + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if hasattr(self, "emb_l"): + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + @staticmethod def _set_cond_input(aux_input: Dict): """Set the speaker conditioning input based on the multi-speaker mode.""" @@ -483,58 +758,6 @@ class Vits(BaseTTS): return sid, g, lid - def get_aux_input(self, aux_input: Dict): - sid, g, lid = self._set_cond_input(aux_input) - return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} - - def get_aux_input_from_test_sentences(self, sentence_info): - if hasattr(self.config, "model_args"): - config = self.config.model_args - else: - config = self.config - - # extract speaker and language info - text, speaker_name, style_wav, language_name = None, None, None, None - - if isinstance(sentence_info, list): - if len(sentence_info) == 1: - text = sentence_info[0] - elif len(sentence_info) == 2: - text, speaker_name = sentence_info - elif len(sentence_info) == 3: - text, speaker_name, style_wav = sentence_info - elif len(sentence_info) == 4: - text, speaker_name, style_wav, language_name = sentence_info - else: - text = sentence_info - - # get speaker id/d_vector - speaker_id, d_vector, language_id = None, None, None - if hasattr(self, "speaker_manager"): - if config.use_d_vector_file: - if speaker_name is None: - d_vector = self.speaker_manager.get_random_d_vector() - else: - d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False) - elif config.use_speaker_embedding: - if speaker_name is None: - speaker_id = self.speaker_manager.get_random_speaker_id() - else: - speaker_id = self.speaker_manager.speaker_ids[speaker_name] - - # get language id - if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: - 
language_id = self.language_manager.language_id_mapping[language_name]
-
-        return {
-            "text": text,
-            "speaker_id": speaker_id,
-            "style_wav": style_wav,
-            "d_vector": d_vector,
-            "language_id": language_id,
-            "language_name": language_name,
-        }
-
     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
         speaker_ids = aux_input.get("speaker_ids", None)
@@ -611,7 +834,7 @@ class Vits(BaseTTS):
             - x_lengths: :math:`[B]`
             - y: :math:`[B, C, T_spec]`
             - y_lengths: :math:`[B]`
-            - waveform: :math:`[B, T_wav, 1]`
+            - waveform: :math:`[B, 1, T_wav]`
             - d_vectors: :math:`[B, C, 1]`
             - speaker_ids: :math:`[B]`
             - language_ids: :math:`[B]`
@@ -656,13 +879,14 @@ class Vits(BaseTTS):
         logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

         # select a random feature segment for the waveform decoder
-        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
+        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
         o = self.waveform_decoder(z_slice, g=g)

         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
+            pad_short = True
         )

         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -694,6 +918,7 @@ class Vits(BaseTTS):
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
                 "syn_spk_emb": syn_spk_emb,
+                "slice_ids": slice_ids,
             }
         )
         return outputs
@@ -798,30 +1023,6 @@ class Vits(BaseTTS):
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
         return o_hat, y_mask, (z, z_p, z_hat)

-    def _freeze_layers(self):
-        if self.args.freeze_encoder:
-            for param in self.text_encoder.parameters():
-                param.requires_grad = False
-
-            if hasattr(self, "emb_l"):
-                for param in self.emb_l.parameters():
-                    param.requires_grad = False
-
-        if self.args.freeze_PE:
-            for param in self.posterior_encoder.parameters():
-                param.requires_grad = False
-
-        if self.args.freeze_DP:
-            for param in self.duration_predictor.parameters():
-                param.requires_grad = False
-
-        if self.args.freeze_flow_decoder:
-            for param in self.flow.parameters():
-                param.requires_grad = False
-
-        if self.args.freeze_waveform_decoder:
-            for param in self.waveform_decoder.parameters():
-                param.requires_grad = False

     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Perform a single training step. Run the model forward pass and compute losses.
@@ -835,91 +1036,101 @@
             Tuple[Dict, Dict]: Model outputs and computed losses.
         """

-        # pylint: disable=attribute-defined-outside-init
-        if optimizer_idx not in [0, 1]:
-            raise ValueError(" [!] 
Unexpected `optimizer_idx`.")
-        self._freeze_layers()
+
+        mel_lens = batch["mel_lens"]
+
         if optimizer_idx == 0:
-            text_input = batch["text_input"]
-            text_lengths = batch["text_lengths"]
-            mel_lengths = batch["mel_lengths"]
-            linear_input = batch["linear_input"]
+            tokens = batch["tokens"]
+            token_lengths = batch["token_lens"]
+            spec = batch["spec"]
+            spec_lens = batch["spec_lens"]
+
             d_vectors = batch["d_vectors"]
             speaker_ids = batch["speaker_ids"]
             language_ids = batch["language_ids"]
             waveform = batch["waveform"]

-            # if (waveform > 1).sum() > 0 or (waveform < -1).sum() > 0:
-            #     breakpoint()
-
             # generator pass
             outputs = self.forward(
-                text_input,
-                text_lengths,
-                linear_input.transpose(1, 2),
-                mel_lengths,
-                waveform.transpose(1, 2),
+                tokens,
+                token_lengths,
+                spec,
+                spec_lens,
+                waveform,
                 aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
             )

-            # cache tensors for the discriminator
-            self.y_disc_cache = outputs["model_outputs"]
-            self.wav_seg_disc_cache = outputs["waveform_seg"]
-
-            # compute discriminator scores and features
-            outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
-                outputs["model_outputs"], outputs["waveform_seg"]
-            )
-
-            # compute losses
-            with autocast(enabled=False):  # use float32 for the criterion
-                loss_dict = criterion[optimizer_idx](
-                    waveform_hat=outputs["model_outputs"].float(),
-                    waveform=outputs["waveform_seg"].float(),
-                    z_p=outputs["z_p"].float(),
-                    logs_q=outputs["logs_q"].float(),
-                    m_p=outputs["m_p"].float(),
-                    logs_p=outputs["logs_p"].float(),
-                    z_len=mel_lengths,
-                    scores_disc_fake=outputs["scores_disc_fake"],
-                    feats_disc_fake=outputs["feats_disc_fake"],
-                    feats_disc_real=outputs["feats_disc_real"],
-                    loss_duration=outputs["loss_duration"],
-                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
-                    gt_spk_emb=outputs["gt_spk_emb"],
-                    syn_spk_emb=outputs["syn_spk_emb"],
-                )
-
-            # if loss_dict["loss_feat"].isnan().sum() > 0 or loss_dict["loss_feat"].isinf().sum() > 0:
-            #     breakpoint()
-
-        elif optimizer_idx == 1:
-            # discriminator pass
-            outputs = {}
+            # cache tensors for the generator pass
+            self.model_outputs_cache = outputs

             # compute scores and features
-            outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
-                self.y_disc_cache.detach(), self.wav_seg_disc_cache
+            scores_disc_fake, _, scores_disc_real, _ = self.disc(
+                outputs["model_outputs"].detach(), outputs["waveform_seg"]
             )

             # compute loss
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
-                    outputs["scores_disc_real"],
-                    outputs["scores_disc_fake"],
+                    scores_disc_real,
+                    scores_disc_fake,
                 )
-            return outputs, loss_dict
+            return {}, loss_dict
+
+        if optimizer_idx == 1:
+            mel = batch["mel"]
+
+            # compute melspec segment
+            with autocast(enabled=False):
+                mel_slice = segment(mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True)
+                mel_slice_hat = wav_to_mel(
+                    y = self.model_outputs_cache["model_outputs"].float(),
+                    n_fft = self.config.audio.fft_size,
+                    sample_rate = self.config.audio.sample_rate,
+                    num_mels = self.config.audio.num_mels,
+                    hop_length = self.config.audio.hop_length,
+                    win_length = self.config.audio.win_length,
+                    fmin=self.config.audio.mel_fmin,
+                    fmax=self.config.audio.mel_fmax,
+                    center=False,
+                )
+
+            # compute discriminator scores and features
+            scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
+                self.model_outputs_cache["model_outputs"], 
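+                # both discriminator inputs are reused from the forward pass cached at
+                # optimizer_idx == 0, so the generator is not run a second time here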
self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_slice_hat=mel_slice.float(), + mel_slice=mel_slice_hat.float(), + z_p= self.model_outputs_cache["z_p"].float(), + logs_q= self.model_outputs_cache["logs_q"].float(), + m_p= self.model_outputs_cache["m_p"].float(), + logs_p= self.model_outputs_cache["logs_p"].float(), + z_len=mel_lens, + scores_disc_fake= scores_disc_fake, + feats_disc_fake= feats_disc_fake, + feats_disc_real= feats_disc_real, + loss_duration= self.model_outputs_cache["loss_duration"], + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb= self.model_outputs_cache["gt_spk_emb"], + syn_spk_emb= self.model_outputs_cache["syn_spk_emb"], + ) + + return self.model_outputs_cache, loss_dict + + raise ValueError(" [!] Unexpected `optimizer_idx`.") def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use - y_hat = outputs[0]["model_outputs"] - y = outputs[0]["waveform_seg"] + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] figures = plot_results(y_hat, y, ap, name_prefix) sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() audios = {f"{name_prefix}/audio": sample_voice} - alignments = outputs[0]["alignments"] + alignments = outputs[1]["alignments"] align_img = alignments[0].data.cpu().numpy().T figures.update( @@ -927,7 +1138,6 @@ class Vits(BaseTTS): "alignment": plot_alignment(align_img, output_fig=False), } ) - return figures, audios def train_log( @@ -948,7 +1158,7 @@ class Vits(BaseTTS): """ figures, audios = self._log(self.ap, batch, outputs, "train") logger.train_figures(steps, figures) - logger.train_figures(steps, audios, self.ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): @@ -959,6 +1169,54 @@ class Vits(BaseTTS): logger.eval_figures(steps, figures) logger.eval_audios(steps, audios, self.ap.sample_rate) + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_d_vector() + else: + d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_speaker_id() + else: + speaker_id = self.speaker_manager.speaker_ids[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.language_id_mapping[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + 
"style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } + @torch.no_grad() def test_run(self, assets) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. @@ -973,56 +1231,187 @@ class Vits(BaseTTS): test_figures = {} test_sentences = self.config.test_sentences for idx, s_info in enumerate(test_sentences): - try: - aux_inputs = self.get_aux_input_from_test_sentences(s_info) - wav, alignment, _, _ = synthesis( - self, - aux_inputs["text"], - self.config, - "cuda" in str(next(self.parameters()).device), - speaker_id=aux_inputs["speaker_id"], - d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], - language_id=aux_inputs["language_id"], - use_griffin_lim=True, - do_trim_silence=False, - ).values() - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) - except: # pylint: disable=bare-except - print(" !! Error creating Test Sentence -", idx) + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) return {"figures": test_figures, "audios": test_audios} def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) logger.test_figures(steps, outputs["figures"]) + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.d_vectors + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if self.language_manager is not None and self.language_manager.language_id_mapping and self.args.use_language_embedding: + language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + # compute spectrograms + batch["spec"] = wav_to_spec( + batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False + ) + batch["mel"] = spec_to_mel( + spec = batch["spec"], + n_fft = ac.fft_size, + num_mels = ac.num_mels, + sample_rate = 
ac.sample_rate,
+            fmin = ac.mel_fmin,
+            fmax = ac.mel_fmax,
+        )
+        assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
+
+        # compute spectrogram frame lengths
+        batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
+        batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
+        assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
+
+        # zero the padding frames
+        batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1)
+        batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
+        return batch
+
+    def get_data_loader(
+        self,
+        config: Coqpit,
+        assets: Dict,
+        is_eval: bool,
+        samples: Union[List[Dict], List[List]],
+        verbose: bool,
+        num_gpus: int,
+        rank: int = None,
+    ) -> "DataLoader":
+        if is_eval and not config.run_eval:
+            loader = None
+        else:
+            # setup multi-speaker attributes
+            speaker_id_mapping = None
+            d_vector_mapping = None
+            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
+                if hasattr(config, "model_args"):
+                    speaker_id_mapping = (
+                        self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                    )
+                    d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
+                    config.use_d_vector_file = config.model_args.use_d_vector_file
+                else:
+                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
+                    d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
+
+            # setup multi-lingual attributes
+            language_id_mapping = None
+            if hasattr(self, "language_manager"):
+                language_id_mapping = (
+                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+                )
+
+            # init dataloader
+            dataset = VitsDataset(
+                samples=samples,
+                # batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
+                min_text_len=config.min_text_len,
+                max_text_len=config.max_text_len,
+                min_audio_len=config.min_audio_len,
+                max_audio_len=config.max_audio_len,
+                phoneme_cache_path=config.phoneme_cache_path,
+                precompute_num_workers=config.precompute_num_workers,
+                verbose=verbose,
+                tokenizer=self.tokenizer,
+                start_by_longest=config.start_by_longest,
+            )
+
+            # wait for all DDP processes to be ready
+            if num_gpus > 1:
+                dist.barrier()
+
+            # sort input sequences from short to long
+            dataset.preprocess_samples()
+
+            # sampler for DDP
+            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+
+            # Weighted samplers
+            # TODO: make this DDP amenable
+            assert not (
+                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
+            ), "language_weighted_sampler is not supported with DistributedSampler"
+            assert not (
+                num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
+            ), "speaker_weighted_sampler is not supported with DistributedSampler"
+
+            if sampler is None:
+                if getattr(config, "use_language_weighted_sampler", False):
+                    print(" > Using Language weighted sampler")
+                    sampler = get_language_weighted_sampler(dataset.samples)
+                elif getattr(config, "use_speaker_weighted_sampler", False):
+                    print(" > Using Speaker weighted sampler")
+                    sampler = get_speaker_weighted_sampler(dataset.samples)
+
+            loader = DataLoader(
+                dataset,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                shuffle=False,  # shuffle is done in the dataset.
+                drop_last=False,  # setting this False might cause issues in AMP training.
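+                # a model-specific collate_fn is required here since VitsDataset returns
+                # variable-length token and waveform tensors that must be padded per batch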
+                collate_fn=dataset.collate_fn,
+                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                pin_memory=False,
+            )
+        return loader
+
     def get_optimizer(self) -> List:
         """Initiate and return the GAN optimizers based on the config parameters.
-        It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
-
         Returns:
             List: optimizers.
         """
-        gen_parameters = chain(
-            self.text_encoder.parameters(),
-            self.posterior_encoder.parameters(),
-            self.flow.parameters(),
-            self.duration_predictor.parameters(),
-            self.waveform_decoder.parameters(),
-        )
-        # add the speaker embedding layer
-        if hasattr(self, "emb_g") and self.args.use_speaker_embedding and not self.args.use_d_vector_file:
-            gen_parameters = chain(gen_parameters, self.emb_g.parameters())
-        # add the language embedding layer
-        if hasattr(self, "emb_l") and self.args.use_language_embedding:
-            gen_parameters = chain(gen_parameters, self.emb_l.parameters())
+        optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)

-        optimizer0 = get_optimizer(
+        # select generator parameters
+        gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
+        optimizer1 = get_optimizer(
             self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
         )
-        optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
         return [optimizer0, optimizer1]

     def get_lr(self) -> List:
@@ -1031,7 +1420,7 @@ class Vits(BaseTTS):
         Returns:
             List: learning rates for each optimizer.
         """
-        return [self.config.lr_gen, self.config.lr_disc]
+        return [self.config.lr_disc, self.config.lr_gen]

     def get_scheduler(self, optimizer) -> List:
         """Set the schedulers for each optimizer.
@@ -1042,9 +1431,9 @@ class Vits(BaseTTS):
         Returns:
             List: Schedulers, one for each optimizer.
         """
-        scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
-        scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
-        return [scheduler0, scheduler1]
+        scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
+        scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
+        return [scheduler_D, scheduler_G]

     def get_criterion(self):
         """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
@@ -1054,10 +1443,14 @@ class Vits(BaseTTS):
             VitsGeneratorLoss,
         )

-        return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
+        return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)]

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self,
+        config,
+        checkpoint_path,
+        eval=False,
+        strict=True,
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
         state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
         # as it is probably easier for model distribution.
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} # handle fine-tuning from a checkpoint with additional speakers - if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: - num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + if hasattr(self, "emb_g") and state["model"]["vits.emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["vits.emb_g.weight"].shape[0] print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") - emb_g = state["model"]["emb_g.weight"] + emb_g = state["model"]["vits.emb_g.weight"] new_row = torch.randn(num_new_speakers, emb_g.shape[1]) emb_g = torch.cat([emb_g, new_row], axis=0) - state["model"]["emb_g.weight"] = emb_g + state["model"]["vits.emb_g.weight"] = emb_g + # load the model weights + self.load_state_dict(state["model"], strict=strict) - self.load_state_dict(state["model"], strict=False) if eval: self.eval() assert not self.training @@ -1090,12 +1484,21 @@ class Vits(BaseTTS): """ from TTS.utils.audio import AudioProcessor + upsample_rate = math.prod(config.model_args.upsample_rates_decoder) + assert ( + upsample_rate == config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + ap = AudioProcessor.init_from_config(config, verbose=verbose) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) language_manager = LanguageManager.init_from_config(config) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) +################################## +# VITS CHARACTERS +################################## + class VitsCharacters(BaseCharacters): """Characters class for VITs model for compatibility with pre-trained models""" From c11944022d35abeba070669a5db6258e76cdfcfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:51:56 +0100 Subject: [PATCH 175/214] Revert back again rand_segment --- TTS/tts/utils/helpers.py | 52 +++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 9ccb5d62..1366c4a6 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -57,7 +57,7 @@ def sequence_mask(sequence_length, max_len=None): return mask -def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4): +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): """Segment each sample in a batch based on the provided segment indices Args: @@ -66,16 +66,25 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4): segment_size (int): Expected output segment size. pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. 
""" - ret = torch.zeros_like(x[:, :, :segment_size]) + # pad the input tensor if it is shorter than the segment size + if pad_short and x.shape[-1] < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + + segments = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): - idx_str = segment_indices[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret + index_start = segment_indices[i] + index_end = index_start + segment_size + x_i = x[i] + if pad_short and index_end > x.size(2): + # pad the sample if it is shorter than the segment size + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + segments[i] = x_i[:, index_start:index_end] + return segments def rand_segments( - x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4 + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False ): """Create random segments based on the input lengths. @@ -90,16 +99,25 @@ def rand_segments( - x: :math:`[B, C, T]` - x_lengths: :math:`[B]` """ - b, _, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - if (ids_str_max < 0).sum(): - raise ValueError("Segment size is larger than the input length.") - ids_str = (torch.rand([b]).to(x.device) * ids_str_max).long() - ret = segment(x, ids_str, segment_size) - return ret, ids_str - + _x_lenghts = x_lengths.clone() + B, _, T = x.size() + if pad_short: + if T < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - T)) + T = segment_size + if _x_lenghts is None: + _x_lenghts = T + len_diff = _x_lenghts - segment_size + 1 + if let_short_samples: + _x_lenghts[len_diff < 0] = segment_size + len_diff = _x_lenghts - segment_size + 1 + else: + assert all( + len_diff > 0 + ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() + ret = segment(x, segment_indices, segment_size) + return ret, segment_indices def average_over_durations(values, durs): """Average values over durations. From c68962c57409f34c8a88a0163f5f476e3b87d2ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:53:44 +0100 Subject: [PATCH 176/214] Update forward tts binary loss --- TTS/tts/configs/fast_pitch_config.py | 3 +++ TTS/tts/configs/fast_speech_config.py | 6 +++--- TTS/tts/configs/speedy_speech_config.py | 6 +++--- TTS/tts/models/forward_tts.py | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index de870388..024040f8 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -92,6 +92,9 @@ class FastPitchConfig(BaseTTSConfig): binary_align_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + min_seq_len (int): Minimum input sequence length to be used at training. diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py index 31d99442..f0c23593 100644 --- a/TTS/tts/configs/fast_speech_config.py +++ b/TTS/tts/configs/fast_speech_config.py @@ -93,8 +93,8 @@ class FastSpeechConfig(BaseTTSConfig): binary_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. 
- binary_align_loss_start_step (int): - Start binary alignment loss after this many steps. Defaults to 20000. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. min_seq_len (int): Minimum input sequence length to be used at training. @@ -135,7 +135,7 @@ class FastSpeechConfig(BaseTTSConfig): pitch_loss_alpha: float = 0.0 aligner_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 1.0 - binary_align_loss_start_step: int = 20000 + binary_align_loss_start_step: int = 50000 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py index ea6866ed..4bf5101f 100644 --- a/TTS/tts/configs/speedy_speech_config.py +++ b/TTS/tts/configs/speedy_speech_config.py @@ -89,8 +89,8 @@ class SpeedySpeechConfig(BaseTTSConfig): binary_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. - binary_align_loss_start_step (int): - Start binary alignment loss after this many steps. Defaults to 20000. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. min_seq_len (int): Minimum input sequence length to be used at training. @@ -150,7 +150,7 @@ class SpeedySpeechConfig(BaseTTSConfig): spec_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 0.3 - binary_align_loss_start_step: int = 50000 + binary_loss_warmup_epochs: int = 150 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 8d554f76..db8fef2d 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -178,8 +178,8 @@ class ForwardTTS(BaseTTS): tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, ): - super().__init__(config, ap, tokenizer, speaker_manager) + self._set_model_args(config) self.init_multispeaker(config) From 52a7896668e0ac1bd8b177688b8a68bc660bfacc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:54:05 +0100 Subject: [PATCH 177/214] Update VITS loss --- TTS/tts/layers/losses.py | 42 +++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 0c94f91f..57d36717 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -587,13 +587,12 @@ class VitsGeneratorLoss(nn.Module): @staticmethod def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): - l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() - return l + return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() def forward( self, - waveform, - waveform_hat, + mel_slice, + mel_slice_hat, z_p, logs_q, m_p, @@ -609,8 +608,8 @@ class VitsGeneratorLoss(nn.Module): ): """ Shapes: - - waveform : :math:`[B, 1, T]` - - waveform_hat: :math:`[B, 1, T]` + - mel_slice : :math:`[B, 1, T]` + - mel_slice_hat: :math:`[B, 1, T]` - z_p: :math:`[B, C, T]` - logs_q: :math:`[B, C, T]` - m_p: :math:`[B, C, T]` @@ -623,30 +622,23 @@ class VitsGeneratorLoss(nn.Module): loss = 0.0 return_dict = {} z_mask = sequence_mask(z_len).float() - # compute mel spectrograms from the waveforms - mel = self.stft(waveform) - mel_hat = self.stft(waveform_hat) - # compute losses - loss_kl = self.kl_loss( - z_p=z_p, - logs_q=logs_q, - m_p=m_p, - logs_p=logs_p, - z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha - loss_feat = self.feature_loss( - 
feats_real=feats_disc_real,
-            feats_generated=feats_disc_fake) * self.feat_loss_alpha
+        loss_kl = (
+            self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
+            * self.kl_loss_alpha
+        )
+        loss_feat = (
+            self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
+        )
         loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
-        loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
+        loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration

         if use_speaker_encoder_as_loss:
             loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
-            loss += loss_se
+            loss = loss + loss_se
             return_dict["loss_spk_encoder"] = loss_se
-
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
@@ -675,16 +667,18 @@ class VitsDiscriminatorLoss(nn.Module):
             loss += real_loss + fake_loss
             real_losses.append(real_loss.item())
             fake_losses.append(fake_loss.item())
-
         return loss, real_losses, fake_losses

     def forward(self, scores_disc_real, scores_disc_fake):
         loss = 0.0
         return_dict = {}
-        loss_disc, _, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
+        loss_disc, loss_disc_real, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
         return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
         loss = loss + return_dict["loss_disc"]
         return_dict["loss"] = loss
+
+        for i, ldr in enumerate(loss_disc_real):
+            return_dict[f"loss_disc_real_{i}"] = ldr
         return return_dict

From 750903d2bac5030bb21526e54f4dab23857833f1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sun, 20 Feb 2022 11:54:53 +0100
Subject: [PATCH 178/214] Add VCTK formatter docstring

---
 TTS/tts/datasets/formatters.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index 5a38039b..4592ccce 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -289,8 +289,27 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
     return items


-def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None):
-    """https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"""
+def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None):
+    """VCTK dataset v0.92.
+
+    URL:
+        https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip
+
+    This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```.
+    It is believed (😄) that the ```mic1``` files are the same as in the previous version of the dataset.
+
+    mic1:
+        Audio recorded using an omni-directional microphone (DPA 4035).
+        Contains very low frequency noises.
+        This is the same audio released in previous versions of VCTK:
+        https://doi.org/10.7488/ds/1994
+
+    mic2:
+        Audio recorded using a small diaphragm condenser microphone with
+        very wide bandwidth (Sennheiser MKH 800).
+        Two speakers, p280 and p315, had technical issues with the audio
+        recordings made with the MKH 800.
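+
+    Example:
+        A minimal usage sketch (the dataset path is illustrative)::
+
+            >>> items = vctk("/data/VCTK-Corpus-0.92", mic="mic2")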
+ """ file_ext = "flac" items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) From ff23dce081cf07ecf5f7b66bdefdec43195bba84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:55:40 +0100 Subject: [PATCH 179/214] Update TTSDataset --- TTS/tts/datasets/dataset.py | 95 +++++-------------------------------- 1 file changed, 13 insertions(+), 82 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index af726818..865209c2 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -37,10 +37,10 @@ def noise_augment_audio(wav): class TTSDataset(Dataset): def __init__( self, - outputs_per_step: int, - compute_linear_spec: bool, - ap: AudioProcessor, - samples: List[Dict], + outputs_per_step: int = 1, + compute_linear_spec: bool = False, + ap: AudioProcessor = None, + samples: List[Dict] = None, tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, f0_cache_path: str = None, @@ -118,7 +118,6 @@ class TTSDataset(Dataset): self.batch_group_size = batch_group_size self._samples = samples self.outputs_per_step = outputs_per_step - self.sample_rate = ap.sample_rate self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav self.compute_f0 = compute_f0 @@ -153,6 +152,15 @@ class TTSDataset(Dataset): if self.verbose: self.print_logs() + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + @property def samples(self): return self._samples @@ -763,80 +771,3 @@ class F0Dataset: print(f"{indent}| > Number of instances : {len(self.samples)}") -# if __name__ == "__main__": -# from torch.utils.data import DataLoader - -# from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig -# from TTS.tts.datasets import load_tts_samples -# from TTS.tts.utils.text.characters import IPAPhonemes -# from TTS.tts.utils.text.phonemizers import ESpeak - -# dataset_config = BaseDatasetConfig( -# name="ljspeech", -# meta_file_train="metadata.csv", -# path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1", -# ) -# train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) -# samples = train_samples + eval_samples - -# phonemizer = ESpeak(language="en-us") -# tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer) -# # ph_dataset = PhonemeDataset(samples, tokenizer, phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests") -# # ph_dataset.precompute(num_workers=4) - -# # dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn) -# # for batch in dataloader: -# # print(batch) -# # break - -# audio_config = BaseAudioConfig( -# sample_rate=22050, -# win_length=1024, -# hop_length=256, -# num_mels=80, -# preemphasis=0.0, -# ref_level_db=20, -# log_func="np.log", -# do_trim_silence=True, -# trim_db=45, -# mel_fmin=0, -# mel_fmax=8000, -# spec_gain=1.0, -# signal_norm=False, -# do_amp_to_db_linear=False, -# ) - -# ap = AudioProcessor.init_from_config(audio_config) - -# # f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4) - -# # dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn) -# # for batch in dataloader: -# # print(batch) -# # breakpoint() 
-# # break - -# dataset = TTSDataset( -# outputs_per_step=1, -# compute_linear_spec=False, -# samples=samples, -# ap=ap, -# return_wav=False, -# batch_group_size=0, -# min_seq_len=0, -# max_seq_len=500, -# use_noise_augment=False, -# verbose=True, -# speaker_id_mapping=None, -# d_vector_mapping=None, -# compute_f0=True, -# f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests", -# tokenizer=tokenizer, -# phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests", -# precompute_num_workers=4, -# ) - -# dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn) -# for batch in dataloader: -# print(batch) -# break From 8b3ba02c953cf66b4263330b2f96f89b9d7fa299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:56:02 +0100 Subject: [PATCH 180/214] Add vocab_dict to model config --- TTS/tts/configs/shared_configs.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 96cf0427..f43c6464 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass, field -from typing import List +from typing import Dict, List from coqpit import Coqpit, check_argument @@ -50,13 +50,16 @@ class GSTConfig(Coqpit): @dataclass class CharactersConfig(Coqpit): - """Defines arguments for the `BaseCharacters` and its subclasses. + """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses. Args: characters_class (str): Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on the configuration. Defaults to None. + vocab_dict (dict): + Defines the vocabulary dictionary used to encode the characters. Defaults to None. + pad (str): characters in place of empty padding. Defaults to None. 
@@ -89,6 +92,11 @@
     """

     characters_class: str = None
+
+    # using BaseVocabulary
+    vocab_dict: Dict = None
+
+    # using BaseCharacters
     pad: str = None
     eos: str = None
     bos: str = None

From c0b40a0cb76b3d7668521537433e944c78ad5bf1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sun, 20 Feb 2022 11:56:21 +0100
Subject: [PATCH 181/214] Update VITS tests

---
 tests/tts_tests/test_vits.py | 114 ++++++++++++++++++++++-------------
 1 file changed, 71 insertions(+), 43 deletions(-)

diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index eaa325b0..4018c6bd 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -3,17 +3,19 @@ import os
 import unittest

 import torch
+from TTS.tts.datasets.formatters import ljspeech

 from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.config import load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.logging.tensorboard_logger import TensorboardLogger
+from trainer.logging.tensorboard_logger import TensorboardLogger

 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
 SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
+WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


 torch.manual_seed(1)
@@ -23,6 +25,28 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

 # pylint: disable=no-self-use
 class TestVits(unittest.TestCase):
+    def test_load_audio(self):
+        wav, sr = load_audio(WAV_FILE)
+        self.assertEqual(wav.shape, (1, 41885))
+        self.assertEqual(sr, 22050)
+
+        spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
+        mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False)
+        mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
+
+        self.assertEqual((mel - mel2).abs().max(), 0)
+        self.assertEqual(spec.shape[0], mel.shape[0])
+        self.assertEqual(spec.shape[2], mel.shape[2])
+
+        spec_db = amp_to_db(spec)
+        spec_amp = db_to_amp(spec_db)
+
+        self.assertAlmostEqual((spec - spec_amp).abs().max(), 0, delta=1e-4)
+
+    def test_dataset(self):
+        """TODO:"""
+        ...
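+        # One possible shape check for this stub (a sketch; the sample list and
+        # tokenizer are assumed to be available, e.g. via the ljspeech formatter above):
+        #   dataset = VitsDataset(samples=samples, tokenizer=tokenizer)
+        #   batch = dataset.collate_fn([dataset[0], dataset[1]])
+        #   self.assertEqual(batch["waveform"].shape[1], 1)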
+ def test_init_multispeaker(self): num_speakers = 10 args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True) @@ -107,10 +131,11 @@ class TestVits(unittest.TestCase): input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) input_lengths[-1] = 128 spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device) + mel = torch.rand(batch_size, config.audio["num_mels"], 30).to(device) spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) spec_lengths[-1] = spec.size(2) waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device) - return input_dummy, input_lengths, spec, spec_lengths, waveform + return input_dummy, input_lengths, mel, spec, spec_lengths, waveform def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2): self.assertEqual( @@ -139,7 +164,7 @@ class TestVits(unittest.TestCase): num_speakers = 0 config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) config.model_args.spec_segment_size = 10 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config) + input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config) model = Vits(config).to(device) output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform) self._check_forward_outputs(config, output_dict) @@ -150,7 +175,7 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) config.model_args.spec_segment_size = 10 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config) + input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config) speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) model = Vits(config).to(device) @@ -171,7 +196,7 @@ class TestVits(unittest.TestCase): config = VitsConfig(model_args=args) model = Vits.init_from_config(config, verbose=False).to(device) model.train() - input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) d_vectors = torch.randn(batch_size, 256).to(device) output_dict = model.forward( input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors} @@ -186,7 +211,7 @@ class TestVits(unittest.TestCase): args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10) config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) - input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) @@ -221,7 +246,7 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) config.audio.sample_rate = 16000 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) lang_ids = 
torch.randint(0, num_langs, (batch_size,)).long().to(device) @@ -330,20 +355,25 @@ class TestVits(unittest.TestCase): @staticmethod def _check_parameter_changes(model, model_ref): count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): + for item1, item2 in zip(model.named_parameters(), model_ref.named_parameters()): + name = item1[0] + param = item1[1] + param_ref = item2[1] assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref + name, param.shape, param, param_ref ) - count += 1 + count = count + 1 def _create_batch(self, config, batch_size): - input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size) + input_dummy, input_lengths, mel, spec, mel_lengths, _ = self._create_inputs(config, batch_size) batch = {} - batch["text_input"] = input_dummy - batch["text_lengths"] = input_lengths - batch["mel_lengths"] = mel_lengths - batch["linear_input"] = mel_spec.transpose(1, 2) - batch["waveform"] = torch.rand(batch_size, config.audio["sample_rate"] * 10, 1).to(device) + batch["tokens"] = input_dummy + batch["token_lens"] = input_lengths + batch["spec_lens"] = mel_lengths + batch["mel_lens"] = mel_lengths + batch["spec"] = spec + batch["mel"] = mel + batch["waveform"] = torch.rand(batch_size, 1, config.audio["sample_rate"] * 10).to(device) batch["d_vectors"] = None batch["speaker_ids"] = None batch["language_ids"] = None @@ -351,33 +381,31 @@ class TestVits(unittest.TestCase): def test_train_step(self): # setup the model - config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10)) - model = Vits(config).to(device) - # create a batch - batch = self._create_batch(config, 1) - # model to train - criterions = model.get_criterion() - criterions = [criterions[0].to(device), criterions[1].to(device)] - # reference model to compare model weights - model_ref = Vits(config).to(device) - model.train() - # pass the state to ref model - model_ref.load_state_dict(copy.deepcopy(model.state_dict())) - count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): - assert (param - param_ref).sum() == 0, param - count += 1 - optimizers = model.get_optimizer() - for _ in range(5): - _, loss_dict = model.train_step(batch, criterions, 0) - loss = loss_dict["loss"] - loss.backward() - optimizers[0].step() + with torch.autograd.set_detect_anomaly(True): + + config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10)) + model = Vits(config).to(device) + model.train() + # model to train + optimizers = model.get_optimizer() + criterions = model.get_criterion() + criterions = [criterions[0].to(device), criterions[1].to(device)] + # reference model to compare model weights + model_ref = Vits(config).to(device) + # # pass the state to ref model + model_ref.load_state_dict(copy.deepcopy(model.state_dict())) + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count = count + 1 + for _ in range(5): + batch = self._create_batch(config, 2) + for idx in [0, 1]: + _, loss_dict = model.train_step(batch, criterions, idx) + loss_dict["loss"].backward() + optimizers[idx].step() + optimizers[idx].zero_grad() - _, loss_dict = model.train_step(batch, criterions, 1) - loss = loss_dict["loss"] - loss.backward() - optimizers[1].step() # check parameter changes self._check_parameter_changes(model, model_ref) From fc8264d9d24b20755e65f3ba745656c3c4e335e7 Mon Sep 17 
00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sun, 20 Feb 2022 12:08:37 +0100
Subject: [PATCH 182/214] Update requirements

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 54a6bdfd..c60e0817 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,7 +24,7 @@ matplotlib
 tensorboardX
 pyworld
 # coqui stack
-git+https://github.com/coqui-ai/Trainer@main # trainer
+trainer @ git+https://github.com/coqui-ai/trainer.git
 coqpit # config management
 # chinese g2p deps
 jieba

From 424d04e4f6dfca0cb34c7957a970734304cdfb78 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sun, 20 Feb 2022 12:37:27 +0100
Subject: [PATCH 183/214] Make style

---
 TTS/bin/distribute.py                         |   1 -
 TTS/bin/train_encoder.py                      |   1 -
 TTS/bin/train_tts.py                          |   2 +-
 TTS/bin/train_vocoder.py                      |   2 +-
 TTS/model.py                                  |  16 +--
 TTS/speaker_encoder/utils/training.py         |   9 +-
 TTS/tts/datasets/dataset.py                   |   2 -
 TTS/tts/layers/losses.py                      |   4 +-
 TTS/tts/models/base_tts.py                    |   8 +-
 TTS/tts/models/glow_tts.py                    |   3 +-
 TTS/tts/models/vits.py                        | 121 ++++++++----------
 TTS/tts/utils/helpers.py                      |   1 +
 TTS/tts/utils/text/characters.py              |   1 -
 TTS/tts/utils/text/punctuation.py             |   2 +-
 TTS/vocoder/models/wavegrad.py                |   6 +-
 .../multilingual/vits_tts/train_vits_tts.py   |   4 +-
 tests/tts_tests/test_glow_tts.py              |   2 +-
 tests/tts_tests/test_vits.py                  |  29 +++--
 18 files changed, 103 insertions(+), 111 deletions(-)

diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py
index 40f60d5d..97e2f0e3 100644
--- a/TTS/bin/distribute.py
+++ b/TTS/bin/distribute.py
@@ -7,7 +7,6 @@ import subprocess
 import time

 import torch
-
 from trainer import TrainerArgs


diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index f19966ee..5828411c 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -8,7 +8,6 @@ import traceback

 import torch
 from torch.utils.data import DataLoader
-
 from trainer.torch import NoamLR

 from TTS.speaker_encoder.dataset import SpeakerEncoderDataset

diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index 467685b2..31813712 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass, field
 import os
+from dataclasses import dataclass, field

 from trainer import Trainer, TrainerArgs

diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py
index c52fd962..32ecd7bd 100644
--- a/TTS/bin/train_vocoder.py
+++ b/TTS/bin/train_vocoder.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass, field
 import os
+from dataclasses import dataclass, field

 from trainer import Trainer, TrainerArgs

diff --git a/TTS/model.py b/TTS/model.py
index d7bd4f9f..39cbeabc 100644
--- a/TTS/model.py
+++ b/TTS/model.py
@@ -5,11 +5,11 @@ import torch
 from coqpit import Coqpit
 from torch import nn

+# pylint: skip-file

 class BaseTrainerModel(ABC, nn.Module):
-    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
-    """
+    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""

     @staticmethod
     @abstractmethod
@@ -63,7 +63,7 @@ class BaseTrainerModel(ABC, nn.Module):
         """
         return batch

-    def format_batch_on_device(self, batch:Dict) -> Dict:
+    def format_batch_on_device(self, batch: Dict) -> Dict:
         """Format batch on device before sending it to the model.

         If not implemented, model uses the batch as is.
@@ -124,7 +124,6 @@ class BaseTrainerModel(ABC, nn.Module):
         """The same as `train_log()`"""
         ...
- @abstractmethod def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None: """Load a checkpoint and get ready for training or inference. @@ -148,13 +147,8 @@ class BaseTrainerModel(ABC, nn.Module): @abstractmethod def get_data_loader( - self, - config: Coqpit, - assets: Dict, - is_eval: True, - data_items: List, - verbose: bool, - num_gpus: int): + self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int + ): ... # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index 5c2de274..c64c46b7 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -1,16 +1,15 @@ -from asyncio.log import logger -from dataclasses import dataclass, field import os +from dataclasses import dataclass, field from coqpit import Coqpit +from trainer import TrainerArgs +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config -from trainer import TrainerArgs from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from trainer.logging import logger_factory -from trainer.logging.console_logger import ConsoleLogger from TTS.utils.trainer_utils import get_last_checkpoint diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 865209c2..d4d1a7e5 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -769,5 +769,3 @@ class F0Dataset: print("\n") print(f"{indent}> F0Dataset ") print(f"{indent}| > Number of instances : {len(self.samples)}") - - diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 57d36717..e03cf084 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -672,7 +672,9 @@ class VitsDiscriminatorLoss(nn.Module): def forward(self, scores_disc_real, scores_disc_fake): loss = 0.0 return_dict = {} - loss_disc, loss_disc_real, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake) + loss_disc, loss_disc_real, _ = self.discriminator_loss( + scores_real=scores_disc_real, scores_fake=scores_disc_fake + ) return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha loss = loss + return_dict["loss_disc"] return_dict["loss"] = loss diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 6dd7ca72..dd6539a5 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -26,8 +26,12 @@ class BaseTTS(BaseTrainerModel): """ def __init__( - self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None, - language_manager: LanguageManager = None + self, + config: Coqpit, + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, ): super().__init__() self.config = config diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 23eb48da..c30f043a 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -530,7 +530,8 @@ class GlowTTS(BaseTTS): self.store_inverse() assert not self.training - def get_criterion(self): + @staticmethod + def get_criterion(): from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel return GlowTTSLoss() diff --git 
a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index b7766e92..ec6c9e5b 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,9 +1,8 @@ -import collections import math import os from dataclasses import dataclass, field, replace from itertools import chain -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple, Union import torch import torch.distributed as dist @@ -25,7 +24,7 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDuration from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer @@ -38,6 +37,7 @@ from TTS.vocoder.utils.generic_utils import plot_results # IO / Feature extraction ############################## +# pylint: disable=global-statement hann_window = {} mel_basis = {} @@ -200,7 +200,7 @@ class VitsDataset(TTSDataset): text, wav_file, speaker_name, language_name, _ = _parse_sample(item) raw_text = text - wav, sr = load_audio(wav_file) + wav, _ = load_audio(wav_file) wav_filename = os.path.basename(wav_file) token_ids = self.get_token_ids(idx, text) @@ -538,12 +538,14 @@ class Vits(BaseTTS): >>> model = Vits(config) """ - def __init__(self, + def __init__( + self, config: Coqpit, ap: "AudioProcessor" = None, tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, - language_manager: LanguageManager = None,): + language_manager: LanguageManager = None, + ): super().__init__(config, ap, tokenizer, speaker_manager, language_manager) @@ -673,9 +675,9 @@ class Vits(BaseTTS): ) # pylint: disable=W0101,W0105 self.audio_transform = torchaudio.transforms.Resample( - orig_freq=self.config.audio.sample_rate, - new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], - ) + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], + ) def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init @@ -777,9 +779,9 @@ class Vits(BaseTTS): with torch.no_grad(): o_scale = torch.exp(-2 * logs_p) logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)]) + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) - logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp2 + logp3 + logp1 + logp4 attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] @@ -806,7 +808,7 @@ class Vits(BaseTTS): outputs["loss_duration"] = loss_duration return outputs, attn - def forward( + def forward( # pylint: disable=dangerous-default-value self, x: torch.tensor, x_lengths: torch.tensor, @@ -886,7 +888,7 @@ class Vits(BaseTTS): waveform, slice_ids * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length, - pad_short = True + 
pad_short=True, ) if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: @@ -929,7 +931,9 @@ class Vits(BaseTTS): return aux_input["x_lengths"] return torch.tensor(x.shape[1:2]).to(x.device) - def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}): + def inference( + self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} + ): # pylint: disable=dangerous-default-value """ Note: To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. @@ -1023,7 +1027,6 @@ class Vits(BaseTTS): o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) return o_hat, y_mask, (z, z_p, z_hat) - def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: """Perform a single training step. Run the model forward pass and compute losses. @@ -1062,7 +1065,7 @@ class Vits(BaseTTS): ) # cache tensors for the generator pass - self.model_outputs_cache = outputs + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init # compute scores and features scores_disc_fake, _, scores_disc_real, _ = self.disc( @@ -1082,14 +1085,16 @@ class Vits(BaseTTS): # compute melspec segment with autocast(enabled=False): - mel_slice = segment(mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True) + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True + ) mel_slice_hat = wav_to_mel( - y = self.model_outputs_cache["model_outputs"].float(), - n_fft = self.config.audio.fft_size, - sample_rate = self.config.audio.sample_rate, - num_mels = self.config.audio.num_mels, - hop_length = self.config.audio.hop_length, - win_length = self.config.audio.win_length, + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.config.audio.fft_size, + sample_rate=self.config.audio.sample_rate, + num_mels=self.config.audio.num_mels, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, fmin=self.config.audio.mel_fmin, fmax=self.config.audio.mel_fmax, center=False, @@ -1097,7 +1102,7 @@ class Vits(BaseTTS): # compute discriminator scores and features scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( - self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] ) # compute losses @@ -1105,18 +1110,18 @@ class Vits(BaseTTS): loss_dict = criterion[optimizer_idx]( mel_slice_hat=mel_slice.float(), mel_slice=mel_slice_hat.float(), - z_p= self.model_outputs_cache["z_p"].float(), - logs_q= self.model_outputs_cache["logs_q"].float(), - m_p= self.model_outputs_cache["m_p"].float(), - logs_p= self.model_outputs_cache["logs_p"].float(), + z_p=self.model_outputs_cache["z_p"].float(), + logs_q=self.model_outputs_cache["logs_q"].float(), + m_p=self.model_outputs_cache["m_p"].float(), + logs_p=self.model_outputs_cache["logs_p"].float(), z_len=mel_lens, - scores_disc_fake= scores_disc_fake, - feats_disc_fake= feats_disc_fake, - feats_disc_real= feats_disc_real, - loss_duration= self.model_outputs_cache["loss_duration"], + scores_disc_fake=scores_disc_fake, + feats_disc_fake=feats_disc_fake, + feats_disc_real=feats_disc_real, + loss_duration=self.model_outputs_cache["loss_duration"], use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, - gt_spk_emb= 
self.model_outputs_cache["gt_spk_emb"], - syn_spk_emb= self.model_outputs_cache["syn_spk_emb"], + gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], + syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], ) return self.model_outputs_cache, loss_dict @@ -1248,7 +1253,9 @@ class Vits(BaseTTS): test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) return {"figures": test_figures, "audios": test_audios} - def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) logger.test_figures(steps, outputs["figures"]) @@ -1273,7 +1280,11 @@ class Vits(BaseTTS): d_vectors = torch.FloatTensor(d_vectors) # get language ids from language names - if self.language_manager is not None and self.language_manager.language_id_mapping and self.args.use_language_embedding: + if ( + self.language_manager is not None + and self.language_manager.language_id_mapping + and self.args.use_language_embedding + ): language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]] if language_ids is not None: @@ -1289,16 +1300,14 @@ class Vits(BaseTTS): ac = self.config.audio # compute spectrograms - batch["spec"] = wav_to_spec( - batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False - ) + batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) batch["mel"] = spec_to_mel( - spec = batch["spec"], - n_fft = ac.fft_size, - num_mels = ac.num_mels, - sample_rate = ac.sample_rate, - fmin = ac.mel_fmin, - fmax = ac.mel_fmax, + spec=batch["spec"], + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, ) assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" @@ -1325,27 +1334,6 @@ class Vits(BaseTTS): if is_eval and not config.run_eval: loader = None else: - # setup multi-speaker attributes - speaker_id_mapping = None - d_vector_mapping = None - if hasattr(self, "speaker_manager") and self.speaker_manager is not None: - if hasattr(config, "model_args"): - speaker_id_mapping = ( - self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None - ) - d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None - config.use_d_vector_file = config.model_args.use_d_vector_file - else: - speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None - d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None - - # setup multi-lingual attributes - language_id_mapping = None - if hasattr(self, "language_manager"): - language_id_mapping = ( - self.language_manager.language_id_mapping if self.args.use_language_embedding else None - ) - # init dataloader dataset = VitsDataset( samples=samples, @@ -1495,6 +1483,7 @@ class Vits(BaseTTS): language_manager = LanguageManager.init_from_config(config) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + ################################## # VITS CHARACTERS ################################## diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 1366c4a6..c2e7f561 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -119,6 +119,7 @@ def rand_segments( ret = 
segment(x, segment_indices, segment_size) return ret, segment_indices + def average_over_durations(values, durs): """Average values over durations. diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index f6c04370..0ce65a90 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,4 +1,3 @@ -from abc import ABC from dataclasses import replace from typing import Dict diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py index 09087d5f..b2a058bb 100644 --- a/TTS/tts/utils/text/punctuation.py +++ b/TTS/tts/utils/text/punctuation.py @@ -57,7 +57,7 @@ class Punctuation: if not isinstance(value, six.string_types): raise ValueError("[!] Punctuations must be of type str.") self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder - self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+") + self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+") def strip(self, text): """Remove all the punctuations by replacing with `space`. diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 58fc8762..750258af 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -270,7 +270,7 @@ class Wavegrad(BaseVocoder): ) -> None: pass - def test(self, assets: Dict, test_loader:"DataLoader", outputs=None): # pylint: disable=unused-argument + def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument # setup noise schedule and inference ap = assets["audio_processor"] noise_schedule = self.config["test_noise_schedule"] @@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder): y = y.unsqueeze(1) return {"input": m, "waveform": y} - def get_data_loader( - self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int - ): + def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): ap = assets["audio_processor"] dataset = WaveGradDataset( ap=ap, diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py index ea4f377b..ac2c21a2 100644 --- a/recipes/multilingual/vits_tts/train_vits_tts.py +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -69,8 +69,8 @@ config = VitsConfig( print_eval=False, mixed_precision=False, sort_by_audio_len=True, - min_seq_len=32 * 256 * 4, - max_seq_len=160000, + min_audio_len=32 * 256 * 4, + max_audio_len=160000, output_path=output_path, datasets=dataset_config, characters={ diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index 85b5ed7a..2783e4bd 100644 --- a/tests/tts_tests/test_glow_tts.py +++ b/tests/tts_tests/test_glow_tts.py @@ -4,6 +4,7 @@ import unittest import torch from torch import optim +from trainer.logging.tensorboard_logger import TensorboardLogger from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path from TTS.tts.configs.glow_tts_config import GlowTTSConfig @@ -11,7 +12,6 @@ from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.models.glow_tts import GlowTTS from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor -from trainer.logging.tensorboard_logger import TensorboardLogger # pylint: disable=unused-variable diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 4018c6bd..204ff2f7 100644 --- a/tests/tts_tests/test_vits.py +++ 
b/tests/tts_tests/test_vits.py @@ -3,15 +3,14 @@ import os import unittest import torch -from TTS.tts.datasets.formatters import ljspeech +from trainer.logging.tensorboard_logger import TensorboardLogger from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path from TTS.config import load_config from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.tts.configs.vits_config import VitsConfig -from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset +from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec from TTS.tts.utils.speakers import SpeakerManager -from trainer.logging.tensorboard_logger import TensorboardLogger LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") @@ -31,7 +30,17 @@ class TestVits(unittest.TestCase): self.assertEqual(sr, 22050) spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False) - mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False) + mel = wav_to_mel( + wav, + n_fft=1024, + num_mels=80, + sample_rate=sr, + hop_length=512, + win_length=1024, + fmin=0, + fmax=8000, + center=False, + ) mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000) self.assertEqual((mel - mel2).abs().max(), 0) @@ -45,7 +54,7 @@ class TestVits(unittest.TestCase): def test_dataset(self): """TODO:""" - ... + ... def test_init_multispeaker(self): num_speakers = 10 @@ -164,7 +173,7 @@ class TestVits(unittest.TestCase): num_speakers = 0 config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) config.model_args.spec_segment_size = 10 - input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config) + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config) model = Vits(config).to(device) output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform) self._check_forward_outputs(config, output_dict) @@ -175,7 +184,7 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) config.model_args.spec_segment_size = 10 - input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config) + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config) speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) model = Vits(config).to(device) @@ -196,7 +205,7 @@ class TestVits(unittest.TestCase): config = VitsConfig(model_args=args) model = Vits.init_from_config(config, verbose=False).to(device) model.train() - input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) d_vectors = torch.randn(batch_size, 256).to(device) output_dict = model.forward( input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors} @@ -211,7 +220,7 @@ class TestVits(unittest.TestCase): args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10) config = VitsConfig(num_speakers=num_speakers, 
use_speaker_embedding=True, model_args=args) - input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) @@ -246,7 +255,7 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) config.audio.sample_rate = 16000 - input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) From 14c117978dff3e8c0a21e46f7063d0c3ce103e92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 21 Feb 2022 09:57:57 +0100 Subject: [PATCH 184/214] Fix return outputs --- TTS/tts/models/vits.py | 2 +- tests/tts_tests/test_vits.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index ec6c9e5b..02542f71 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1078,7 +1078,7 @@ class Vits(BaseTTS): scores_disc_real, scores_disc_fake, ) - return {}, loss_dict + return outputs, loss_dict if optimizer_idx == 1: mel = batch["mel"] diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 204ff2f7..384234e5 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -410,7 +410,9 @@ class TestVits(unittest.TestCase): for _ in range(5): batch = self._create_batch(config, 2) for idx in [0, 1]: - _, loss_dict = model.train_step(batch, criterions, idx) + outputs, loss_dict = model.train_step(batch, criterions, idx) + self.assertFalse(not outputs) + self.assertFalse(not loss_dict) loss_dict["loss"].backward() optimizers[idx].step() optimizers[idx].zero_grad() From 83c5ddc5b70fa8c2ddd79a85c4c4a48987dd5502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 22 Feb 2022 11:30:41 +0100 Subject: [PATCH 185/214] Update imports --- TTS/speaker_encoder/utils/training.py | 2 +- TTS/tts/models/vits.py | 2 +- TTS/vocoder/models/gan.py | 2 +- TTS/vocoder/models/wavegrad.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index c64c46b7..0bc72af8 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -10,7 +10,7 @@ from TTS.config import load_config, register_config from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint @dataclass diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 02542f71..0c795ca1 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -29,7 +29,7 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import 
plot_alignment -from TTS.utils.trainer_utils import get_optimizer, get_scheduler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index d4abaa0a..91467956 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -10,7 +10,7 @@ from torch.utils.data.distributed import DistributedSampler from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec -from TTS.utils.trainer_utils import get_optimizer, get_scheduler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.models import setup_discriminator, setup_generator diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 750258af..95aa3cd2 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from TTS.utils.io import load_fsspec -from TTS.utils.trainer_utils import get_optimizer, get_scheduler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock from TTS.vocoder.models.base_vocoder import BaseVocoder From 4c43eda414bcfb5ef7b4925f53402bdcf31a50c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:32:28 +0100 Subject: [PATCH 186/214] Update BaseTrainerModel --- TTS/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/model.py b/TTS/model.py index 39cbeabc..ab52be81 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -5,11 +5,11 @@ import torch from coqpit import Coqpit from torch import nn -# pylint: skip-file class BaseTrainerModel(ABC, nn.Module): - """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.""" + """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this. 
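# ----------------------------------------------------------------------
# [Editor's note, not part of the patch series] PATCH 185 above moves the
# optimizer, scheduler, and checkpoint helpers out of TTS.utils and into
# the external `trainer` package. A minimal sketch of the import layout
# the series converges on, using only paths that appear in these diffs:
#
#     from trainer import TrainerArgs, get_last_checkpoint
#     from trainer.logging import logger_factory
#     from trainer.trainer_utils import get_optimizer, get_scheduler
# ----------------------------------------------------------------------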
+ """ @staticmethod @abstractmethod From bf540f43237b0b85e0b03d86c40347de42b5c76f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:35:42 +0100 Subject: [PATCH 187/214] Update imports for trainer --- TTS/bin/train_encoder.py | 2 ++ TTS/bin/train_tts.py | 1 + TTS/bin/train_vocoder.py | 1 + TTS/speaker_encoder/utils/training.py | 6 +++++- 4 files changed, 9 insertions(+), 1 deletion(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 5828411c..33724919 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -10,6 +10,8 @@ import torch from torch.utils.data import DataLoader from trainer.torch import NoamLR +from trainer.torch import NoamLR + from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 31813712..1bca7430 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os from dataclasses import dataclass, field diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 32ecd7bd..1745d6ab 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os from dataclasses import dataclass, field diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index 0bc72af8..c07915c9 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -1,3 +1,5 @@ +from asyncio.log import logger +from dataclasses import dataclass, field import os from dataclasses import dataclass, field @@ -7,10 +9,12 @@ from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config +from trainer import TrainerArgs, get_last_checkpoint from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from trainer import get_last_checkpoint +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger @dataclass From e0f9be76c0f7c4f1f7de372aff93ef0eb6b32e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:36:27 +0100 Subject: [PATCH 188/214] Update test_run in wavernn and wavegrad --- TTS/vocoder/models/wavegrad.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 95aa3cd2..02c28c23 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -307,7 +307,9 @@ class Wavegrad(BaseVocoder): y = y.unsqueeze(1) return {"input": m, "waveform": y} - def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): + def get_data_loader( + self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int + ): ap = assets["audio_processor"] dataset = WaveGradDataset( ap=ap, From bed4afd4ee61dbbf9f95e443631f0acd21791d10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:47:42 +0100 Subject: [PATCH 189/214] Implement BaseVocabulary --- TTS/tts/utils/text/characters.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index 0ce65a90..f6c04370 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,3 +1,4 @@ +from abc import ABC from dataclasses import replace from typing import Dict From fe656659be9faf6e48fa0816bf158df66d7b8877 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:50:13 +0100 Subject: [PATCH 190/214] Implement BaseTTS --- TTS/tts/models/base_tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index dd6539a5..4e54b947 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -12,7 +12,7 @@ from torch.utils.data.distributed import DistributedSampler from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram From acc83cd3e66fb6329c203d415636b7493221acba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:51:37 +0100 Subject: [PATCH 191/214] Update Vits model API --- TTS/tts/models/vits.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 0c795ca1..6ff53c71 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,8 +1,9 @@ +import collections import math import os from dataclasses import dataclass, field, replace from itertools import chain -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -544,8 +545,7 @@ class Vits(BaseTTS): ap: "AudioProcessor" = None, tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, - language_manager: LanguageManager = None, - ): + language_manager: LanguageManager = None,): super().__init__(config, ap, tokenizer, speaker_manager, language_manager) @@ -1483,6 +1483,10 @@ class Vits(BaseTTS): language_manager = LanguageManager.init_from_config(config) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) +################################## +# VITS CHARACTERS +################################## + ################################## # VITS CHARACTERS From 1e414b3a09a7fe09965b76d0192139092acc5253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 12:37:27 +0100 Subject: [PATCH 192/214] Make stlye --- TTS/bin/train_encoder.py | 2 -- TTS/bin/train_tts.py | 1 - TTS/bin/train_vocoder.py | 1 - TTS/model.py | 4 ++-- TTS/speaker_encoder/utils/training.py | 7 +------ TTS/tts/models/vits.py | 7 ++++--- TTS/tts/utils/text/characters.py | 1 - TTS/vocoder/models/wavegrad.py | 4 +--- 8 files changed, 8 insertions(+), 19 deletions(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 33724919..5828411c 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -10,8 +10,6 @@ import torch from torch.utils.data import DataLoader from trainer.torch import NoamLR -from trainer.torch import NoamLR - from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import 
AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 1bca7430..31813712 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, field import os from dataclasses import dataclass, field diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 1745d6ab..32ecd7bd 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, field import os from dataclasses import dataclass, field diff --git a/TTS/model.py b/TTS/model.py index ab52be81..39cbeabc 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -5,11 +5,11 @@ import torch from coqpit import Coqpit from torch import nn +# pylint: skip-file class BaseTrainerModel(ABC, nn.Module): - """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this. - """ + """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.""" @staticmethod @abstractmethod diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index c07915c9..7c58a232 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -1,20 +1,15 @@ -from asyncio.log import logger -from dataclasses import dataclass, field import os from dataclasses import dataclass, field from coqpit import Coqpit -from trainer import TrainerArgs +from trainer import TrainerArgs, get_last_checkpoint from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config -from trainer import TrainerArgs, get_last_checkpoint from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from trainer.logging import logger_factory -from trainer.logging.console_logger import ConsoleLogger @dataclass diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 6ff53c71..04e84c62 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,9 +1,8 @@ -import collections import math import os from dataclasses import dataclass, field, replace from itertools import chain -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple, Union import torch import torch.distributed as dist @@ -545,7 +544,8 @@ class Vits(BaseTTS): ap: "AudioProcessor" = None, tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, - language_manager: LanguageManager = None,): + language_manager: LanguageManager = None, + ): super().__init__(config, ap, tokenizer, speaker_manager, language_manager) @@ -1483,6 +1483,7 @@ class Vits(BaseTTS): language_manager = LanguageManager.init_from_config(config) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + ################################## # VITS CHARACTERS ################################## diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index f6c04370..0ce65a90 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,4 +1,3 @@ -from abc import ABC from dataclasses import replace from typing import Dict diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 02c28c23..95aa3cd2 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -307,9 +307,7 @@ 
class Wavegrad(BaseVocoder): y = y.unsqueeze(1) return {"input": m, "waveform": y} - def get_data_loader( - self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int - ): + def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): ap = assets["audio_processor"] dataset = WaveGradDataset( ap=ap, From 906339789256e74c772f7821dae25463020b2598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 22 Feb 2022 14:25:39 +0100 Subject: [PATCH 193/214] Fix FastSpeech config --- TTS/tts/configs/fast_speech_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py index f0c23593..16a76e21 100644 --- a/TTS/tts/configs/fast_speech_config.py +++ b/TTS/tts/configs/fast_speech_config.py @@ -135,7 +135,7 @@ class FastSpeechConfig(BaseTTSConfig): pitch_loss_alpha: float = 0.0 aligner_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 1.0 - binary_align_loss_start_step: int = 50000 + binary_loss_warmup_epochs: int = 150 # overrides min_seq_len: int = 13 From 7de5afc29ad50f88b8a6cebd3395c123f68cd00b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 22 Feb 2022 15:08:38 +0100 Subject: [PATCH 194/214] Add text processing tests --- Makefile | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 3ef57285..d04cd976 100644 --- a/Makefile +++ b/Makefile @@ -27,11 +27,14 @@ test_zoo: ## run zoo tests. nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id inference_tests: ## run inference tests. - nosetests tests.inference_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.inference_tests --nologcapture --with-id + nosetests tests.inference_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.inference_tests --nologcapture --with-id -data_tests: ## run data tests. +data_tests: ## run data tests. nosetests tests.data_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.data_tests --nologcapture --with-id +test_text: ## run text tests. + nosetests tests.text_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.text_tests --nologcapture --with-id + test_failed: ## only run tests failed the last time. 
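# ----------------------------------------------------------------------
# [Editor's note, not part of the patch series] PATCH 193 above trades the
# step-based `binary_align_loss_start_step` for an epoch-based
# `binary_loss_warmup_epochs`, and PATCH 194 adds a dedicated Makefile
# target so the text-processing tests can run in isolation:
#
#     make test_text
#
# (this assumes the tests live under tests/text_tests, matching the
# nosetests invocation in the new target)
# ----------------------------------------------------------------------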
nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed From 690de1ab06c4709acfade4a399cbfa4299a29784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 23 Feb 2022 12:54:36 +0100 Subject: [PATCH 195/214] Update Characters and add more tests --- TTS/tts/utils/text/characters.py | 57 ++++++----------------------- tests/text_tests/test_characters.py | 46 ++++++++++++++++++++++- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index 0ce65a90..1b375e4f 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -35,51 +35,6 @@ _diacrilics = "ɚ˞ɫ" _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics -# def create_graphemes( -# characters=_characters, -# punctuations=_punctuations, -# pad=_pad, -# eos=_eos, -# bos=_bos, -# blank=_blank, -# unique=True, -# ): # pylint: disable=redefined-outer-name -# """Function to create default characters and phonemes""" -# # create graphemes -# = ( -# sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) -# ) # this is to keep previous models compatible. -# _graphemes = list(characters) -# _graphemes = [bos] + _graphemes if len(bos) > 0 and bos is not None else _graphemes -# _graphemes = [eos] + _graphemes if len(bos) > 0 and eos is not None else _graphemes -# _graphemes = [pad] + _graphemes if len(bos) > 0 and pad is not None else _graphemes -# _graphemes = [blank] + _graphemes if len(bos) > 0 and blank is not None else _graphemes -# _graphemes = _graphemes + list(punctuations) -# return _graphemes, _phonemes - - -# def create_phonemes( -# phonemes=_phonemes, punctuations=_punctuations, pad=_pad, eos=_eos, bos=_bos, blank=_blank, unique=True -# ): -# # create phonemes -# _phonemes = None -# _phonemes_sorted = ( -# sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) -# ) # this is to keep previous models compatible. -# _phonemes = list(_phonemes_sorted) -# _phonemes = [bos] + _phonemes if len(bos) > 0 and bos is not None else _phonemes -# _phonemes = [eos] + _phonemes if len(bos) > 0 and eos is not None else _phonemes -# _phonemes = [pad] + _phonemes if len(bos) > 0 and pad is not None else _phonemes -# _phonemes = [blank] + _phonemes if len(bos) > 0 and blank is not None else _phonemes -# _phonemes = _phonemes + list(punctuations) -# _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) -# return _phonemes - - -# DEF_GRAPHEMES = create_graphemes(_characters, _phonemes, _punctuations, _pad, _eos, _bos) -# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank) - - class BaseVocabulary: """Base Vocabulary class. @@ -98,18 +53,24 @@ class BaseVocabulary: @property def pad_id(self) -> int: + """Return the index of the padding character. If the padding character is not specified, return the length + of the vocabulary.""" return self.char_to_id(self.pad) if self.pad else len(self.vocab) @property def blank_id(self) -> int: + """Return the index of the blank character. 
If the blank character is not specified, return the length of + the vocabulary.""" return self.char_to_id(self.blank) if self.blank else len(self.vocab) @property def vocab(self): + """Return the vocabulary dictionary.""" return self._vocab @vocab.setter def vocab(self, vocab): + """Set the vocabulary dictionary and character mapping dictionaries.""" self._vocab = vocab self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} self._id_to_char = { @@ -118,6 +79,7 @@ class BaseVocabulary: @staticmethod def init_from_config(config, **kwargs): + """Initialize from the given config.""" if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict: return ( BaseVocabulary( @@ -133,15 +95,18 @@ class BaseVocabulary: @property def num_chars(self): - return max(self._vocab.values()) + 1 + """Return number of tokens in the vocabulary.""" + return len(self._vocab) def char_to_id(self, char: str) -> int: + """Map a character to an token ID.""" try: return self._char_to_id[char] except KeyError as e: raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e def id_to_char(self, idx: int) -> str: + """Map an token ID to a character.""" return self._id_to_char[idx] diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py index 5432c652..2ebb3bc3 100644 --- a/tests/text_tests/test_characters.py +++ b/tests/text_tests/test_characters.py @@ -1,9 +1,53 @@ import unittest -from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes +from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, BaseVocabulary # pylint: disable=protected-access +class BaseVocabularyTest(unittest.TestCase): + def setUp(self): + self.phonemes = IPAPhonemes() + self.base_vocab = BaseVocabulary(vocab=self.phonemes._vocab, pad=self.phonemes.pad, blank=self.phonemes.blank, bos=self.phonemes.bos, eos=self.phonemes.eos) + self.empty_vocab = BaseVocabulary({}) + + def test_pad_id(self): + self.assertEqual(self.empty_vocab.pad_id, 0) + self.assertEqual(self.base_vocab.pad_id, self.phonemes.pad_id) + + def test_blank_id(self): + self.assertEqual(self.empty_vocab.blank_id, 0) + self.assertEqual(self.base_vocab.blank_id, self.phonemes.blank_id) + + def test_vocab(self): + self.assertEqual(self.empty_vocab.vocab, {}) + self.assertEqual(self.base_vocab.vocab, self.phonemes._vocab) + + def test_init_from_config(self): + ... 
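# ----------------------------------------------------------------------
# [Editor's note, not part of the patch series] A minimal usage sketch of
# the BaseVocabulary API that PATCH 195 adds and these tests exercise.
# The toy character set is hypothetical; the constructor, properties, and
# error behaviour follow the characters.py diff above.
from TTS.tts.utils.text.characters import BaseVocabulary

vocab = BaseVocabulary(vocab=["<PAD>", "a", "b", "c"], pad="<PAD>")
assert vocab.num_chars == 4        # now len(vocab) rather than max(id) + 1
assert vocab.pad_id == 0           # falls back to len(vocab) when pad is unset
assert vocab.char_to_id("b") == 2  # unknown characters raise KeyError
assert vocab.id_to_char(2) == "b"
# ----------------------------------------------------------------------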
+ + def test_num_chars(self): + self.assertEqual(self.empty_vocab.num_chars, 0) + self.assertEqual(self.base_vocab.num_chars, self.phonemes.num_chars) + + def test_char_to_id(self): + try: + self.empty_vocab.char_to_id("a") + raise Exception("Should have raised KeyError") + except: + pass + for k in self.phonemes.vocab: + self.assertEqual(self.base_vocab.char_to_id(k), self.phonemes.char_to_id(k)) + + def test_id_to_char(self): + try: + self.empty_vocab.id_to_char(0) + raise Exception("Should have raised KeyError") + except: + pass + for k in self.phonemes.vocab: + v = self.phonemes.char_to_id(k) + self.assertEqual(self.base_vocab.id_to_char(v), self.phonemes.id_to_char(v)) + class BaseCharacterTest(unittest.TestCase): def setUp(self): From 6a9f8074f09a5960e7bc270de69b593281553b06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 1 Mar 2022 07:57:48 +0100 Subject: [PATCH 196/214] Fix TTSDataset --- TTS/tts/datasets/dataset.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index d4d1a7e5..d8f16e4e 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -200,8 +200,8 @@ class TTSDataset(Dataset): def get_f0(self, idx): out_dict = self.f0_dataset[idx] - _, wav_file, *_ = _parse_sample(self.samples[idx]) - assert wav_file == out_dict["audio_file"] + item = self.samples[idx] + assert item["audio_file"] == out_dict["audio_file"] return out_dict @staticmethod @@ -263,10 +263,11 @@ class TTSDataset(Dataset): def _compute_lengths(samples): new_samples = [] for item in samples: - text, wav_file, *_ = _parse_sample(item) - audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio - text_lenght = len(text) - new_samples += [item + [audio_length, text_lenght]] + audio_length = os.path.getsize(item["audio_file"]) / 16 * 8 # assuming 16bit audio + text_lenght = len(item["text"]) + item["audio_length"] = audio_length + item["text_length"] = text_lenght + new_samples += [item] return new_samples @staticmethod @@ -284,7 +285,7 @@ class TTSDataset(Dataset): @staticmethod def sort_by_length(samples: List[List]): - audio_lengths = [s[-2] for s in samples] + audio_lengths = [s["audio_length"] for s in samples] idxs = np.argsort(audio_lengths) # ascending order return idxs @@ -313,8 +314,8 @@ class TTSDataset(Dataset): samples = self._compute_lengths(self.samples) # sort items based on the sequence length in ascending order - text_lengths = [i[-1] for i in samples] - audio_lengths = [i[-2] for i in samples] + text_lengths = [i["text_length"] for i in samples] + audio_lengths = [i["audio_length"] for i in samples] text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len) keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) @@ -341,9 +342,9 @@ class TTSDataset(Dataset): samples = self.create_buckets(samples, self.batch_group_size) # update items to the new sorted items - audio_lengths = [s[-2] for s in samples] - text_lengths = [s[-1] for s in samples] - self.samples = [s[:-2] for s in samples] + audio_lengths = [s["audio_length"] for s in samples] + text_lengths = [s["text_length"] for s in samples] + self.samples = samples if self.verbose: print(" | > Preprocessing samples") @@ -558,10 +559,10 @@ class PhonemeDataset(Dataset): self.precompute(precompute_num_workers) def 
__getitem__(self, index): - text, wav_file, *_ = _parse_sample(self.samples[index]) - ids = self.compute_or_load(wav_file, text) + item = self.samples[index] + ids = self.compute_or_load(item["audio_file"], item["text"]) ph_hat = self.tokenizer.ids_to_text(ids) - return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} + return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} def __len__(self): return len(self.samples) @@ -667,12 +668,12 @@ class F0Dataset: self.load_stats(cache_path) def __getitem__(self, idx): - _, wav_file, *_ = _parse_sample(self.samples[idx]) - f0 = self.compute_or_load(wav_file) + item = self.samples[idx] + f0 = self.compute_or_load(item["audio_file"]) if self.normalize_f0: assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" f0 = self.normalize(f0) - return {"audio_file": wav_file, "f0": f0} + return {"audio_file": item["audio_file"], "f0": f0} def __len__(self): return len(self.samples) From a84499c5da45ae38ec5ff19a8a987cdccfcb8a63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 1 Mar 2022 07:58:12 +0100 Subject: [PATCH 197/214] Add text_tests --- .github/workflows/text_tests.yml | 46 ++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 .github/workflows/text_tests.yml diff --git a/.github/workflows/text_tests.yml b/.github/workflows/text_tests.yml new file mode 100644 index 00000000..6056cf6b --- /dev/null +++ b/.github/workflows/text_tests.yml @@ -0,0 +1,46 @@ +name: tts-tests + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + check_skip: + runs-on: ubuntu-latest + if: "! contains(github.event.head_commit.message, '[ci skip]')" + steps: + - run: echo "${{ github.event.head_commit.message }}" + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + experimental: [false] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: coqui-ai/setup-python@pip-cache-key-py-ver + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel + - name: Install TTS + run: | + python3 -m pip install .[all] + python3 setup.py egg_info + - name: Unit tests + run: make test_text From 942df0fb05ce70cd741d975f9d61bbfcb94e9e54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 2 Mar 2022 09:14:32 +0100 Subject: [PATCH 198/214] Update vits dataset --- TTS/tts/models/vits.py | 14 ++++++-------- TTS/utils/synthesizer.py | 20 -------------------- 2 files changed, 6 insertions(+), 28 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 04e84c62..036f22f2 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -196,14 +196,12 @@ class VitsDataset(TTSDataset): def __getitem__(self, idx): item = self.samples[idx] + raw_text = item["text"] - text, wav_file, speaker_name, language_name, _ = _parse_sample(item) - raw_text = text + wav, _ = load_audio(item["audio_file"]) + wav_filename = os.path.basename(item["audio_file"]) - wav, _ = load_audio(wav_file) - wav_filename = 
os.path.basename(wav_file) - - token_ids = self.get_token_ids(idx, text) + token_ids = self.get_token_ids(idx, item["text"]) # after phonemization the text length may change # this is a shameful 🤭 hack to prevent longer phonemes @@ -218,8 +216,8 @@ class VitsDataset(TTSDataset): "token_len": len(token_ids), "wav": wav, "wav_file": wav_filename, - "speaker_name": speaker_name, - "language_name": language_name, + "speaker_name": item["speaker_name"], + "language_name": item["language"], } @property diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 6821e975..d1abc907 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -126,26 +126,6 @@ class Synthesizer(object): self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path - def _is_use_speaker_embedding(self): - """Check if the speaker embedding is used in the model""" - # we handle here the case that some models use model_args some don't - use_speaker_embedding = False - if hasattr(self.tts_config, "model_args"): - use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False) - use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False) - return use_speaker_embedding - - def _is_use_d_vector_file(self): - """Check if the d-vector file is used in the model""" - # we handle here the case that some models use model_args some don't - use_d_vector_file = False - if hasattr(self.tts_config, "model_args"): - config = self.tts_config.model_args - use_d_vector_file = config.get("use_d_vector_file", False) - config = self.tts_config - use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False) - return use_d_vector_file - def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: """Load the vocoder model. From 27b67b7945ca7cd8c7ad1ce60744892a4eb47716 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 2 Mar 2022 09:15:20 +0100 Subject: [PATCH 199/214] Fix import --- TTS/tts/configs/tacotron_config.py | 3 +++ TTS/tts/models/forward_tts.py | 2 +- tests/data_tests/test_loader.py | 6 +++--- tests/tts_tests/test_align_tts_train.py | 2 +- tests/tts_tests/test_fast_pitch_speaker_emb_train.py | 4 ++-- tests/tts_tests/test_fast_pitch_train.py | 2 +- tests/tts_tests/test_glow_tts_d-vectors_train.py | 3 +-- tests/tts_tests/test_glow_tts_speaker_emb_train.py | 3 +-- tests/tts_tests/test_glow_tts_train.py | 2 +- tests/tts_tests/test_speedy_speech_train.py | 2 +- tests/tts_tests/test_tacotron2_d-vectors_train.py | 2 +- tests/tts_tests/test_tacotron2_speaker_emb_train.py | 3 ++- tests/tts_tests/test_tacotron2_train.py | 2 +- tests/tts_tests/test_tacotron_train.py | 2 +- tests/tts_tests/test_vits_multilingual_speaker_emb_train.py | 2 +- tests/tts_tests/test_vits_multilingual_train-d_vectors.py | 2 +- tests/tts_tests/test_vits_speaker_emb_train.py | 2 +- tests/tts_tests/test_vits_train.py | 2 +- 18 files changed, 24 insertions(+), 22 deletions(-) diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py index d6edd267..5193c224 100644 --- a/TTS/tts/configs/tacotron_config.py +++ b/TTS/tts/configs/tacotron_config.py @@ -83,6 +83,8 @@ class TacotronConfig(BaseTTSConfig): ddc_r (int): reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this as a multiple of the `r` value. Defaults to 6. 
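# ----------------------------------------------------------------------
# [Editor's note, not part of the patch series] PATCHes 196 and 198 above
# switch dataset samples from positional lists to dicts. A hypothetical
# sample showing the keys the new __getitem__ code reads:
#
#     sample = {
#         "text": "hello world.",
#         "audio_file": "/data/wavs/utt_0001.wav",
#         "speaker_name": "ljspeech",
#         "language": "en",
#     }
#
# TTSDataset._compute_lengths() then adds "audio_length" and
# "text_length" to each sample in place before filtering and sorting.
# ----------------------------------------------------------------------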
+ speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. use_speaker_embedding (bool): enable / disable using speaker embeddings for multi-speaker models. If set True, the model is in the multi-speaker mode. Defaults to False. @@ -176,6 +178,7 @@ class TacotronConfig(BaseTTSConfig): ddc_r: int = 6 # multi-speaker settings + speakers_file: str = None use_speaker_embedding: bool = False speaker_embedding_dim: int = 512 use_d_vector_file: bool = False diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index db8fef2d..a1273f7f 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -261,7 +261,7 @@ class ForwardTTS(BaseTTS): # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: print(" > Init speaker_embedding layer.") - self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels) + self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @staticmethod diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 4d8cc68a..2727bbdd 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -44,13 +44,13 @@ class TestTTSDataset(unittest.TestCase): self.max_loader_iter = 4 self.ap = AudioProcessor(**c.audio) - def _create_dataloader(self, batch_size, r, bgs): + def _create_dataloader(self, batch_size, r, bgs, start_by_longest=False): # load dataset meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2) items = meta_data_train + meta_data_eval - tokenizer = TTSTokenizer.init_from_config(c) + tokenizer, _ = TTSTokenizer.init_from_config(c) dataset = TTSDataset( outputs_per_step=r, compute_linear_spec=True, @@ -77,7 +77,7 @@ class TestTTSDataset(unittest.TestCase): def test_loader(self): if ok_ljspeech: - dataloader, dataset = self._create_dataloader(2, c.r, 0) + dataloader, dataset = self._create_dataloader(1, 1, 0) for i, data in enumerate(dataloader): if i == self.max_loader_iter: diff --git a/tests/tts_tests/test_align_tts_train.py b/tests/tts_tests/test_align_tts_train.py index d5115af6..6c68d8c9 100644 --- a/tests/tts_tests/test_align_tts_train.py +++ b/tests/tts_tests/test_align_tts_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.align_tts_config import AlignTTSConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py index 59e90e0a..88505988 100644 --- a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py +++ b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py @@ -5,9 +5,9 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint -config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +config_path = os.path.join(get_tests_output_path(), "fast_pitch_speaker_emb_config.json") output_path = os.path.join(get_tests_output_path(), 
"train_outputs") audio_config = BaseAudioConfig( diff --git a/tests/tts_tests/test_fast_pitch_train.py b/tests/tts_tests/test_fast_pitch_train.py index bbfbb823..5a51f0bb 100644 --- a/tests/tts_tests/test_fast_pitch_train.py +++ b/tests/tts_tests/test_fast_pitch_train.py @@ -5,7 +5,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_glow_tts_d-vectors_train.py b/tests/tts_tests/test_glow_tts_d-vectors_train.py index c85e6bcd..dd5e954e 100644 --- a/tests/tts_tests/test_glow_tts_d-vectors_train.py +++ b/tests/tts_tests/test_glow_tts_d-vectors_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -17,7 +17,6 @@ config = GlowTTSConfig( num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, diff --git a/tests/tts_tests/test_glow_tts_speaker_emb_train.py b/tests/tts_tests/test_glow_tts_speaker_emb_train.py index 7e6aabde..df86cf05 100644 --- a/tests/tts_tests/test_glow_tts_speaker_emb_train.py +++ b/tests/tts_tests/test_glow_tts_speaker_emb_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -17,7 +17,6 @@ config = GlowTTSConfig( num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests/test_glow_tts_train.py index e5dc44ee..3a1c4a68 100644 --- a/tests/tts_tests/test_glow_tts_train.py +++ b/tests/tts_tests/test_glow_tts_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_speedy_speech_train.py b/tests/tts_tests/test_speedy_speech_train.py index 7e938a40..98cf8e09 100644 --- a/tests/tts_tests/test_speedy_speech_train.py +++ b/tests/tts_tests/test_speedy_speech_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig -from 
TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py index 0bc31449..e5f83804 100644 --- a/tests/tts_tests/test_tacotron2_d-vectors_train.py +++ b/tests/tts_tests/test_tacotron2_d-vectors_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_tacotron2_speaker_emb_train.py b/tests/tts_tests/test_tacotron2_speaker_emb_train.py index 653933dd..2dd50c73 100644 --- a/tests/tts_tests/test_tacotron2_speaker_emb_train.py +++ b/tests/tts_tests/test_tacotron2_speaker_emb_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -28,6 +28,7 @@ config = Tacotron2Config( "Be a voice, not an echo.", ], use_speaker_embedding=True, + num_speakers=4, max_decoder_steps=50, ) diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index 76727edf..a45065b2 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py index 02491e64..96c63162 100644 --- a/tests/tts_tests/test_tacotron_train.py +++ b/tests/tts_tests/test_tacotron_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron_config import TacotronConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index afa60a1b..c09f8498 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -5,7 +5,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint 
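# ----------------------------------------------------------------------
# [Editor's note, not part of the patch series] PATCH 199 repeats the same
# one-line import swap across the test suite. A sketch of the helper in
# use, assuming `trainer.get_last_checkpoint` keeps the old contract of
# returning (last_checkpoint_path, best_model_path) for a run folder:
from trainer import get_last_checkpoint

continue_path = "/tmp/train_outputs/run-hypothetical"  # illustrative path
restore_path, best_path = get_last_checkpoint(continue_path)
# ----------------------------------------------------------------------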
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py index b0744103..8607a8f7 100644 --- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -5,7 +5,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py index 1aecc596..8a586076 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py index ec9a5915..76c88682 100644 --- a/tests/tts_tests/test_vits_train.py +++ b/tests/tts_tests/test_vits_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig -from TTS.utils.trainer_utils import get_last_checkpoint +from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") From c68885b3fdbd38af63b289d1ca3e5ff03b54ad7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 2 Mar 2022 13:20:23 +0100 Subject: [PATCH 200/214] Update Vits speaker encoder init --- TTS/tts/models/vits.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 036f22f2..8c15103f 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -654,7 +654,7 @@ class Vits(BaseTTS): # TODO: make this a function if self.args.use_speaker_encoder_as_loss: if self.speaker_manager.speaker_encoder is None and ( - not config.speaker_encoder_model_path or not config.speaker_encoder_config_path + not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path ): raise RuntimeError( " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" @@ -1445,13 +1445,13 @@ class Vits(BaseTTS): # as it is probably easier for model distribution. 
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} # handle fine-tuning from a checkpoint with additional speakers - if hasattr(self, "emb_g") and state["model"]["vits.emb_g.weight"].shape != self.emb_g.weight.shape: - num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["vits.emb_g.weight"].shape[0] + if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") - emb_g = state["model"]["vits.emb_g.weight"] + emb_g = state["model"]["emb_g.weight"] new_row = torch.randn(num_new_speakers, emb_g.shape[1]) emb_g = torch.cat([emb_g, new_row], axis=0) - state["model"]["vits.emb_g.weight"] = emb_g + state["model"]["emb_g.weight"] = emb_g # load the model weights self.load_state_dict(state["model"], strict=strict) @@ -1479,14 +1479,12 @@ class Vits(BaseTTS): tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) language_manager = LanguageManager.init_from_config(config) + + if config.model_args.speaker_encoder_model_path is not None: + speaker_manager.init_speaker_encoder(config.model_args.speaker_encoder_model_path, + config.model_args.speaker_encoder_config_path) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) - -################################## -# VITS CHARACTERS -################################## - - ################################## # VITS CHARACTERS ################################## From 1425a023fe4bc6bda8578295aeeeb02af78cc082 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 2 Mar 2022 13:25:35 +0100 Subject: [PATCH 201/214] Make style and lint --- TTS/bin/extract_tts_spectrograms.py | 4 +- TTS/bin/find_unique_chars.py | 4 +- TTS/bin/find_unique_phonemes.py | 4 +- TTS/bin/train_tts.py | 7 +++- TTS/tts/datasets/__init__.py | 37 +++++++++++-------- TTS/tts/datasets/formatters.py | 2 +- TTS/tts/models/glow_tts.py | 4 +- TTS/tts/models/vits.py | 8 ++-- TTS/vocoder/models/gan.py | 2 +- TTS/vocoder/models/wavegrad.py | 2 +- tests/data_tests/test_loader.py | 8 ++-- tests/inference_tests/test_synthesizer.py | 3 +- tests/text_tests/test_characters.py | 15 ++++++-- tests/tts_tests/test_align_tts_train.py | 5 ++- .../test_fast_pitch_speaker_emb_train.py | 5 ++- tests/tts_tests/test_fast_pitch_train.py | 5 ++- .../test_glow_tts_d-vectors_train.py | 5 ++- .../test_glow_tts_speaker_emb_train.py | 5 ++- tests/tts_tests/test_glow_tts_train.py | 5 ++- tests/tts_tests/test_speedy_speech_train.py | 5 ++- .../test_tacotron2_d-vectors_train.py | 5 ++- .../test_tacotron2_speaker_emb_train.py | 5 ++- tests/tts_tests/test_tacotron2_train.py | 5 ++- tests/tts_tests/test_tacotron_train.py | 5 ++- ...est_vits_multilingual_speaker_emb_train.py | 5 ++- .../test_vits_multilingual_train-d_vectors.py | 5 ++- .../tts_tests/test_vits_speaker_emb_train.py | 5 ++- tests/tts_tests/test_vits_train.py | 5 ++- 28 files changed, 108 insertions(+), 67 deletions(-) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 2a2c0b71..fa63c46a 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -229,7 +229,9 @@ def main(args): # pylint: disable=redefined-outer-name ap = AudioProcessor(**c.audio) # load data instances - meta_data_train, meta_data_eval = load_tts_samples(c.datasets, 
eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) + meta_data_train, meta_data_eval = load_tts_samples( + c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) # use eval and training partitions meta_data = meta_data_train + meta_data_eval diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index 541e971b..4689dcad 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -23,7 +23,9 @@ def main(): c = load_config(args.config_path) # load all datasets - train_items, eval_items = load_tts_samples(c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) items = train_items + eval_items diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index 8fe48b2f..0ae74bd4 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -40,7 +40,9 @@ def main(): c = load_config(args.config_path) # load all datasets - train_items, eval_items = load_tts_samples(c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) items = train_items + eval_items print("Num items:", len(items)) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 31813712..976b74af 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -44,7 +44,12 @@ def main(): config = register_config(config_base.model)() # load training samples - train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True, eval_split_max_size=config.eval_split_max_size, eval_split_size=config.eval_split_size) + train_samples, eval_samples = load_tts_samples( + config.datasets, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) # init the model from config model = setup_model(config, train_samples + eval_samples) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index dde85808..6c7c9edd 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -12,20 +12,20 @@ from TTS.tts.datasets.formatters import * def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. - Args: -<<<<<<< HEAD - items (List[List]): - A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + Args: + <<<<<<< HEAD + items (List[List]): + A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. - eval_split_max_size (int): - Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). - eval_split_size (float): - If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. - If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). -======= - items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`. 
->>>>>>> Fix docstring + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + ======= + items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`. + >>>>>>> Fix docstring """ speakers = [item["speaker_name"] for item in items] is_multi_speaker = len(set(speakers)) > 1 @@ -37,7 +37,11 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): else: eval_split_size = int(len(items) * eval_split_size) - assert eval_split_size > 0, " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format(1/len(items)) + assert ( + eval_split_size > 0 + ), " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format( + 1 / len(items) + ) np.random.seed(0) np.random.shuffle(items) if is_multi_speaker: @@ -56,8 +60,11 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): def load_tts_samples( - datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None, - eval_split_max_size=None, eval_split_size=0.01 + datasets: Union[List[Dict], Dict], + eval_split=True, + formatter: Callable = None, + eval_split_max_size=None, + eval_split_size=0.01, ) -> Tuple[List[List], List[List]]: """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 4592ccce..aacfc647 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -132,7 +132,7 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg speaker_id = 0 for idx, line in enumerate(ttf): # 2 samples per speaker to avoid eval split issues - if idx%2 == 0: + if idx % 2 == 0: speaker_id += 1 cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index c30f043a..fea570a6 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -183,8 +183,8 @@ class GlowTTS(BaseTTS): if g is not None: if hasattr(self, "emb_g"): # use speaker embedding layer - if not g.size(): # if is a scalar - g = g.unsqueeze(0) # unsqueeze + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] else: # use d-vector diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 8c15103f..1ad8807f 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -14,6 +14,7 @@ from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset, _parse_sample @@ -29,7 +30,6 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from 
TTS.tts.utils.visual import plot_alignment -from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -1481,10 +1481,12 @@ class Vits(BaseTTS): language_manager = LanguageManager.init_from_config(config) if config.model_args.speaker_encoder_model_path is not None: - speaker_manager.init_speaker_encoder(config.model_args.speaker_encoder_model_path, - config.model_args.speaker_encoder_config_path) + speaker_manager.init_speaker_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + ################################## # VITS CHARACTERS ################################## diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 91467956..3b8a3fbe 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -7,10 +7,10 @@ from coqpit import Coqpit from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec -from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.models import setup_discriminator, setup_generator diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 95aa3cd2..c4968f1f 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -8,9 +8,9 @@ from torch import nn from torch.nn.utils import weight_norm from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.utils.io import load_fsspec -from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock from TTS.vocoder.models.base_vocoder import BaseVocoder diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 2727bbdd..0562fbf7 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -7,7 +7,7 @@ import torch from torch.utils.data import DataLoader from tests import get_tests_output_path -from TTS.tts.configs.shared_configs import BaseTTSConfig, BaseDatasetConfig +from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig from TTS.tts.datasets import TTSDataset, load_tts_samples from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -24,7 +24,7 @@ c.data_path = "tests/data/ljspeech/" ok_ljspeech = os.path.exists(c.data_path) dataset_config = BaseDatasetConfig( - name="ljspeech_test", # ljspeech_test to multi-speaker + name="ljspeech_test", # ljspeech_test to multi-speaker meta_file_train="metadata.csv", meta_file_val=None, path=c.data_path, @@ -106,9 +106,9 @@ class TestTTSDataset(unittest.TestCase): # make sure that the computed mels and the waveform match and correctly computed mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy()) # remove padding in mel-spectrogram - mel_dataloader = mel_input[0].T.numpy()[:, :mel_lengths[0]] + mel_dataloader = mel_input[0].T.numpy()[:, : mel_lengths[0]] # 
guarantee that both mel-spectrograms have the same size and that we will remove waveform padding - mel_new = mel_new[:, :mel_lengths[0]] + mel_new = mel_new[:, : mel_lengths[0]] ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length) mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg] self.assertLess(abs(mel_diff.sum()), 1e-5) diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index 97878574..d643cb81 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -1,13 +1,12 @@ import os import unittest +from tests import get_tests_output_path from TTS.config import load_config from TTS.tts.models import setup_model from TTS.utils.io import save_checkpoint from TTS.utils.synthesizer import Synthesizer -from tests import get_tests_output_path - class SynthesizerTest(unittest.TestCase): # pylint: disable=R0201 diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py index 2ebb3bc3..8f40656a 100644 --- a/tests/text_tests/test_characters.py +++ b/tests/text_tests/test_characters.py @@ -1,13 +1,20 @@ import unittest -from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, BaseVocabulary +from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, Graphemes, IPAPhonemes # pylint: disable=protected-access + class BaseVocabularyTest(unittest.TestCase): def setUp(self): self.phonemes = IPAPhonemes() - self.base_vocab = BaseVocabulary(vocab=self.phonemes._vocab, pad=self.phonemes.pad, blank=self.phonemes.blank, bos=self.phonemes.bos, eos=self.phonemes.eos) + self.base_vocab = BaseVocabulary( + vocab=self.phonemes._vocab, + pad=self.phonemes.pad, + blank=self.phonemes.blank, + bos=self.phonemes.bos, + eos=self.phonemes.eos, + ) self.empty_vocab = BaseVocabulary({}) def test_pad_id(self): @@ -22,8 +29,8 @@ class BaseVocabularyTest(unittest.TestCase): self.assertEqual(self.empty_vocab.vocab, {}) self.assertEqual(self.base_vocab.vocab, self.phonemes._vocab) - def test_init_from_config(self): - ... + # def test_init_from_config(self): + # ... def test_num_chars(self): self.assertEqual(self.empty_vocab.num_chars, 0) diff --git a/tests/tts_tests/test_align_tts_train.py b/tests/tts_tests/test_align_tts_train.py index 6c68d8c9..85dfbbcb 100644 --- a/tests/tts_tests/test_align_tts_train.py +++ b/tests/tts_tests/test_align_tts_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.align_tts_config import AlignTTSConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -51,7 +52,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py index 88505988..37faf449 100644 --- a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py +++ b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py @@ -2,10 +2,11 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "fast_pitch_speaker_emb_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -69,7 +70,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_fast_pitch_train.py b/tests/tts_tests/test_fast_pitch_train.py index 5a51f0bb..d2d78af4 100644 --- a/tests/tts_tests/test_fast_pitch_train.py +++ b/tests/tts_tests/test_fast_pitch_train.py @@ -2,10 +2,11 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -70,7 +71,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_glow_tts_d-vectors_train.py b/tests/tts_tests/test_glow_tts_d-vectors_train.py index dd5e954e..14f9e4d2 100644 --- a/tests/tts_tests/test_glow_tts_d-vectors_train.py +++ b/tests/tts_tests/test_glow_tts_d-vectors_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -56,7 +57,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = config.d_vector_file diff --git a/tests/tts_tests/test_glow_tts_speaker_emb_train.py b/tests/tts_tests/test_glow_tts_speaker_emb_train.py index df86cf05..c327332e 100644 --- a/tests/tts_tests/test_glow_tts_speaker_emb_train.py +++ b/tests/tts_tests/test_glow_tts_speaker_emb_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -53,7 +54,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests/test_glow_tts_train.py index 3a1c4a68..b0acf004 100644 --- a/tests/tts_tests/test_glow_tts_train.py +++ b/tests/tts_tests/test_glow_tts_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -52,7 +53,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_speedy_speech_train.py b/tests/tts_tests/test_speedy_speech_train.py index 98cf8e09..9a26d253 100644 --- a/tests/tts_tests/test_speedy_speech_train.py +++ b/tests/tts_tests/test_speedy_speech_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -51,7 +52,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example for it.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py index e5f83804..6b003f2c 100644 --- a/tests/tts_tests/test_tacotron2_d-vectors_train.py +++ b/tests/tts_tests/test_tacotron2_d-vectors_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -56,7 +57,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = config.d_vector_file diff --git a/tests/tts_tests/test_tacotron2_speaker_emb_train.py b/tests/tts_tests/test_tacotron2_speaker_emb_train.py index 2dd50c73..b9f4de0b 100644 --- a/tests/tts_tests/test_tacotron2_speaker_emb_train.py +++ b/tests/tts_tests/test_tacotron2_speaker_emb_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -54,7 +55,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = 
os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index a45065b2..8c30d9f9 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -51,7 +52,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py index 96c63162..40cd2d3d 100644 --- a/tests/tts_tests/test_tacotron_train.py +++ b/tests/tts_tests/test_tacotron_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron_config import TacotronConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -52,7 +53,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index c09f8498..0c7672d7 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -2,10 +2,11 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -85,7 +86,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech" languae_id = "en" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py index 8607a8f7..a8e2020e 100644 --- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -2,10 +2,11 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -89,7 +90,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" languae_id = "en" continue_speakers_path = config.d_vector_file diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py index 8a586076..c928cee4 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -60,7 +61,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = 
os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py index 76c88682..003f99a8 100644 --- a/tests/tts_tests/test_vits_train.py +++ b/tests/tts_tests/test_vits_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -51,7 +52,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) From fd71893ea9662c5254f6edc52aae255d9160b6e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 2 Mar 2022 18:00:29 +0100 Subject: [PATCH 202/214] Add missing deps for CI tests --- .compute | 16 ---------------- .github/workflows/text_tests.yml | 2 ++ .github/workflows/tts_tests.yml | 2 ++ setup.py | 6 ++---- 4 files changed, 6 insertions(+), 20 deletions(-) delete mode 100644 .compute diff --git a/.compute b/.compute deleted file mode 100644 index 9786a689..00000000 --- a/.compute +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -yes | apt-get install sox -yes | apt-get install ffmpeg -yes | apt-get install tmux -yes | apt-get install zsh -sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" -pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl -sudo sh install.sh -# pip install pytorch==1.7.0+cu100 -# python3 setup.py develop -# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/ -# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/ -# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/ -# python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360 -# python3 distribute.py --config_path config.json -while true; do sleep 1000000; done diff --git a/.github/workflows/text_tests.yml b/.github/workflows/text_tests.yml index 6056cf6b..e06a25ad 100644 --- a/.github/workflows/text_tests.yml +++ b/.github/workflows/text_tests.yml @@ -35,6 +35,8 @@ jobs: run: | sudo apt-get update sudo apt-get install -y --no-install-recommends git make gcc + sudo apt-get install espeak + sudo apt-get install espeak-ng make system-deps - name: Install/upgrade Python setup deps run: python3 -m pip install --upgrade pip setuptools wheel diff --git a/.github/workflows/tts_tests.yml b/.github/workflows/tts_tests.yml index e352a117..0a5891ee 100644 --- a/.github/workflows/tts_tests.yml +++ b/.github/workflows/tts_tests.yml @@ -35,6 +35,8 @@ jobs: run: | sudo apt-get update sudo apt-get 
install -y --no-install-recommends git make gcc + sudo apt-get install espeak + sudo apt-get install espeak-ng make system-deps - name: Install/upgrade Python setup deps run: python3 -m pip install --upgrade pip setuptools wheel diff --git a/setup.py b/setup.py index 1d4dbf1c..96173fec 100644 --- a/setup.py +++ b/setup.py @@ -9,8 +9,8 @@ # ,+++*. . .*++, ,++*. .*+++* # *+, .,*++**. .**++**. ,+* # .+* *+, -# *+. .+* -# *+* +++ +++ *+* +# *+. Coqui .+* +# *+* +++ TTS +++ *+* # .+++*. . . *+++. # ,+* *+++*... ...*+++* *+, # .++. .""""+++++++****+++++++"""". ++. @@ -35,8 +35,6 @@ if LooseVersion(sys.version) < LooseVersion("3.6") or LooseVersion(sys.version) raise RuntimeError("TTS requires python >= 3.6 and <=3.10 " "but your Python version is {}".format(sys.version)) -cwd = os.path.dirname(os.path.abspath(__file__)) - cwd = os.path.dirname(os.path.abspath(__file__)) with open(os.path.join(cwd, "TTS", "VERSION")) as fin: version = fin.read().strip() From 6cb00be795d94a2d36316e0c36750d738890e7f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 2 Mar 2022 18:04:49 +0100 Subject: [PATCH 203/214] Update your_tts model URL --- TTS/.models.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/.models.json b/TTS/.models.json index 61a3257d..2e6c0ebf 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -4,7 +4,7 @@ "multi-dataset":{ "your_tts":{ "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--multilingual--multi-dataset--your_tts.zip", "default_vocoder": null, "commit": "e9a1953e", "license": "CC BY-NC-ND 4.0", From dd4287de1fce944c77cef7498e0daf1a2154abfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Mar 2022 20:23:00 +0100 Subject: [PATCH 204/214] Update models --- TTS/.models.json | 21 ++++++--------------- TTS/tts/models/vits.py | 4 ++-- TTS/tts/utils/synthesis.py | 2 +- TTS/tts/utils/text/tokenizer.py | 3 +++ 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/TTS/.models.json b/TTS/.models.json index 2e6c0ebf..366358be 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -33,7 +33,7 @@ }, "tacotron2-DDC_ph": { "description": "Tacotron2 with Double Decoder Consistency with phonemes.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--tacotron2-DDC_ph.zip", "default_vocoder": "vocoder_models/en/ljspeech/univnet", "commit": "3900448", "author": "Eren Gölge @erogol", @@ -71,7 +71,7 @@ }, "vits": { "description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--vits.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--vits.zip", "default_vocoder": null, "commit": "3900448", "author": "Eren Gölge @erogol", @@ -89,18 +89,9 @@ } }, "vctk": { - "sc-glow-tts": { - "description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip", - "default_vocoder": "vocoder_models/en/vctk/hifigan_v2", - "commit": "b531fa69", - "author": "Edresson Casanova", - 
"license": "", - "contact": "" - }, "vits": { "description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--vctk--vits.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--vits.zip", "default_vocoder": null, "commit": "3900448", "author": "Eren @erogol", @@ -109,7 +100,7 @@ }, "fast_pitch":{ "description": "FastPitch model trained on VCTK dataseset.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--en--vctk--fast_pitch.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--fast_pitch.zip", "default_vocoder": null, "commit": "bdab788d", "author": "Eren @erogol", @@ -156,7 +147,7 @@ "uk":{ "mai": { "glow-tts": { - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--uk--mailabs--glow-tts.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--uk--mai--glow-tts.zip", "author":"@robinhad", "commit": "bdab788d", "license": "MIT", @@ -168,7 +159,7 @@ "zh-CN": { "baker": { "tacotron2-DDC-GST": { - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip", "commit": "unknown", "author": "@kirianguiller", "default_vocoder": null diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 1ad8807f..a43e081c 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1470,7 +1470,7 @@ class Vits(BaseTTS): """ from TTS.utils.audio import AudioProcessor - upsample_rate = math.prod(config.model_args.upsample_rates_decoder) + upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() assert ( upsample_rate == config.audio.hop_length ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" @@ -1480,7 +1480,7 @@ class Vits(BaseTTS): speaker_manager = SpeakerManager.init_from_config(config, samples) language_manager = LanguageManager.init_from_config(config) - if config.model_args.speaker_encoder_model_path is not None: + if config.model_args.speaker_encoder_model_path: speaker_manager.init_speaker_encoder( config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path ) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 377f32de..4ec84a3d 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -167,7 +167,7 @@ def synthesis( style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) # convert text to sequence of token IDs text_inputs = np.asarray( - model.tokenizer.text_to_ids(text), + model.tokenizer.text_to_ids(text, language=language_id), dtype=np.int32, ) # pass tensors to backend diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 50a5f519..f0d85a44 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -93,6 +93,9 @@ class TTSTokenizer: language(str): The language code of the text. Defaults to None. + TODO: + - Add support for language-specific processing. + 1. Text normalizatin 2. Phonemization (if use_phonemes is True) 3. 
Add blank char between characters From bec543b3a5e150e114f0a224b8b84285b0fe46d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 6 Mar 2022 11:48:38 +0100 Subject: [PATCH 205/214] Update zoo tests --- .github/workflows/zoo_tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/zoo_tests.yml b/.github/workflows/zoo_tests.yml index f973dd0e..94d54200 100644 --- a/.github/workflows/zoo_tests.yml +++ b/.github/workflows/zoo_tests.yml @@ -35,6 +35,7 @@ jobs: run: | sudo apt-get update sudo apt-get install -y git make gcc + sudo apt-get install espeak espeak-ng make system-deps - name: Install/upgrade Python setup deps run: python3 -m pip install --upgrade pip setuptools wheel From 764c7fa4a477d54762cf20d4d4d0ef26534e8a93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 6 Mar 2022 12:09:54 +0100 Subject: [PATCH 206/214] Rename phoneme_cleaners --- TTS/tts/utils/synthesis.py | 4 ---- TTS/tts/utils/text/cleaners.py | 21 ++++++++++----------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 4ec84a3d..b6e19ab4 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -119,7 +119,6 @@ def synthesis( do_trim_silence=False, d_vector=None, language_id=None, - language_name=None, ): """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to the vocoder model. @@ -154,9 +153,6 @@ def synthesis( language_id (int): Language ID passed to the language embedding layer in multi-langual model. Defaults to None. - - language_name (str): - Language name corresponding to the language code used by the phonemizer. Defaults to None. """ # GST processing style_mel = None diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index 0ff3e930..cdea3569 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -100,6 +100,16 @@ def english_cleaners(text): return text +def english_phoneme_cleaners(text): + """Pipeline for phonemes mode, including number and abbreviation expansion.""" + text = en_normalize_numbers(text) + text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + def french_cleaners(text): """Pipeline for French text. 
There is no need to expand numbers, phonemizer already does that""" text = expand_abbreviations(text, lang="fr") @@ -126,17 +136,6 @@ def chinese_mandarin_cleaners(text: str) -> str: return text -def phoneme_cleaners(text): - """Pipeline for phonemes mode, including number and abbreviation expansion.""" - text = en_normalize_numbers(text) - # text = convert_to_ascii(text) - text = expand_abbreviations(text) - text = replace_symbols(text) - text = remove_aux_symbols(text) - text = collapse_whitespace(text) - return text - - def multilingual_cleaners(text): """Pipeline for multilingual text""" text = lowercase(text) From e9d9028b4d16ca0c79b8d1918faecba3f0ded22b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 6 Mar 2022 12:57:06 +0100 Subject: [PATCH 207/214] Revert cleaner name --- TTS/tts/utils/text/cleaners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index cdea3569..f02f8fb4 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -100,7 +100,7 @@ def english_cleaners(text): return text -def english_phoneme_cleaners(text): +def phoneme_cleaners(text): """Pipeline for phonemes mode, including number and abbreviation expansion.""" text = en_normalize_numbers(text) text = expand_abbreviations(text) From 00edd3c99b21db4e80211d3279f91bbd2533bfbf Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 1 Mar 2022 17:19:38 +0100 Subject: [PATCH 208/214] feat: add dotbot --- docs/source/_templates/page.html | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 docs/source/_templates/page.html diff --git a/docs/source/_templates/page.html b/docs/source/_templates/page.html new file mode 100644 index 00000000..633b7661 --- /dev/null +++ b/docs/source/_templates/page.html @@ -0,0 +1,23 @@ + +{% extends "!page.html" %} +{% block scripts %} + {{ super() }} + + + + + + + +{% endblock %} From 6716b3b214d3d84e0a0a0c26ea622b07a904307e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 6 Mar 2022 14:10:16 +0100 Subject: [PATCH 209/214] Fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e7774888..80fa5dea 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models - Detailed training logs on the terminal and Tensorboard. - Support for Multi-speaker TTS. - Efficient, flexible, lightweight but feature complete `Trainer API`. -- Released and read-to-use models. +- Released and ready-to-use models. - Tools to curate Text2Speech datasets under```dataset_analysis```. - Utilities to use and test your models. - Modular (but not too much) code base enabling easy implementation of new ideas. From bdebe3d83eb5153dea3ef403cd0928db9a1351d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 6 Mar 2022 14:19:46 +0100 Subject: [PATCH 210/214] Fix typos --- docs/source/finetuning.md | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/docs/source/finetuning.md b/docs/source/finetuning.md index 42b9e518..7d7ef1cb 100644 --- a/docs/source/finetuning.md +++ b/docs/source/finetuning.md @@ -9,33 +9,33 @@ them and fine-tune it for your own dataset. This will help you in two main ways: 1. Faster learning Since a pre-trained model has already learned features that are relevant for the task, it will converge faster on - a new dataset. 
This will reduce the cost of training and let you experient faster.
+   a new dataset. This will reduce the cost of training and let you experiment faster.

2. Better results with small datasets

    Deep learning models are data hungry and they give better performance with more data. However, it is not always
-   possible to have this abondance, especially in domain. For instance, LJSpeech dataset, that we released most of
-   our English models with, is almost 24 hours long. And it requires for someone to collect thid amount of data with
-   a help of a voice talent takes weeks.
+   possible to have this abundance, especially in specific domains. For instance, the LJSpeech dataset, that we released most of
+   our English models with, is almost 24 hours long. It takes weeks to record this amount of data with
+   the help of a voice actor.

-   Fine-tuning cames to rescue in this case. You can take one of our pre-trained models and fine-tune it for your own
-   speech dataset and achive reasonable results with only a couple of hours in the worse case.
+   Fine-tuning comes to the rescue in this case. You can take one of our pre-trained models and fine-tune it on your own
+   speech dataset and achieve reasonable results with only a couple of hours of data.

-   However, note that, fine-tuning does not promise great results. The model performance is still depends on the
+   However, note that fine-tuning does not ensure great results. The model performance still depends on the
    {ref}`dataset quality <what_makes_a_good_dataset>` and the hyper-parameters you choose for fine-tuning. Therefore,
-   it still demands a bit of tinkering.
+   it still takes a bit of tinkering.

## Steps to fine-tune a 🐸 TTS model

1. Setup your dataset.

-   You need to format your target dataset in a certain way so that 🐸TTS data loader would be able to load it for the
+   You need to format your target dataset in a certain way so that 🐸TTS data loader will be able to load it for the
    training. Please see {ref}`this page <formatting_your_dataset>` for more information about formatting.

2. Choose the model you want to fine-tune.

-   You can list the availabe models on terminal as
+   You can list the available models in the command line with

    ```bash
    tts --list_models
    ```

    The command above lists the models in a naming format as ```<model_type>/<language>/<dataset>/<model_name>```.

-   Or you can manually check `.model.json` file in the project directory.
+   Or you can manually check the `.models.json` file in the project directory.

    You should choose the model based on your requirements. Some models are fast and some are better in speech quality.
-   One lazy way to check a model is running the model on the hardware you want to use and see how it works. For
+   One lazy way to test a model is running the model on the hardware you want to use and see how it works. For
    simple testing, you can use the `tts` command on the terminal. For more info see {ref}`here `.

3. Download the model.

-   You can download the model by `tts` command. If you run `tts` with a particular model, it will download automatically
+   You can download the model by using the `tts` command. If you run `tts` with a particular model, it will download it automatically
    and the model path will be printed on the terminal.

4. Setup the model config for fine-tuning.

    ...

    - `run_name` field: This is the name of the run. This is used to name the output directory and the entry in the logging dashboard.
    - `output_path` field: This is the path where the fine-tuned model is saved.
-    - `lr` field: You may need to use a smaller learning rate for fine-tuning not to impair the features learned by the
+    - `lr` field: You may need to use a smaller learning rate for fine-tuning to not lose the features learned by the
    pre-trained model with big update steps.
    - `audio` fields: Different datasets have different audio characteristics. You must check the current audio parameters and make sure that the values reflect your dataset. For instance, your dataset might have a different audio sampling rate.

-   Apart from these above, you should check the whole configuration file and make sure that the values are correct for
+   Apart from the parameters above, you should check the whole configuration file and make sure that the values are correct for
    your dataset and training.

5. Start fine-tuning.

    ...
    --coqpit.lr 0.00001
    ```
-

From 45f1e1f786db8a1f2e7866245b76eca1dda34daf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sun, 6 Mar 2022 14:24:19 +0100
Subject: [PATCH 211/214] Update requirements.txt

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index c60e0817..6e30c26e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 # core deps
 numpy==1.19.5
 cython
-scipy>=0.19.0
+scipy>=1.4.0
 torch>=1.7
 torchaudio
 soundfile

From dc280819be9418d5a72f4fecfa34690e07f5feda Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 7 Mar 2022 12:08:09 +0100
Subject: [PATCH 212/214] Add new models

---
 TTS/.models.json | 57 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/TTS/.models.json b/TTS/.models.json
index 366358be..801b8468 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -197,6 +197,52 @@
                 "commit": "401fbd89"
             }
         }
+    },
+    "tr":{
+        "common-voice": {
+            "glow-tts":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--tr--common-voice--glow-tts.zip",
+                "default_vocoder": "vocoder_models/tr/common-voice/hifigan",
+                "license": "MIT",
+                "description": "Turkish GlowTTS model using an unknown speaker from the Common-Voice dataset.",
+                "author": "Fatih Akademi",
+                "commit": null
+            }
+        }
+    },
+    "it": {
+        "mai_female": {
+            "glow-tts":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--glow-tts.zip",
+                "default_vocoder": null,
+                "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                "author": "@nicolalandro",
+                "commit": null
+            },
+            "vits":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--vits.zip",
+                "default_vocoder": null,
+                "description": "VITS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                "author": "@nicolalandro",
+                "commit": null
+            }
+        },
+        "mai_male": {
+            "glow-tts":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--glow-tts.zip",
+                "default_vocoder": null,
+                "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                "author": "@nicolalandro",
+                "commit": null
+            },
+            "vits":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--vits.zip",
+                "default_vocoder": null,
+                "description": "VITS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                "author": "@nicolalandro",
"author": "@nicolalandro", + "commit": null + } + } } }, "vocoder_models": { @@ -315,6 +361,17 @@ "contact": "" } } + }, + "tr":{ + "common-voice": { + "hifigan":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/vocoder_models--tr--common-voice--hifigan.zip", + "description": "HifiGAN model using an unknown speaker from the Common-Voice dataset.", + "author": "Fatih Akademi", + "license": "MIT", + "commit": null + } + } } } } \ No newline at end of file From ee02bc3823b48d8c02a0ec2d64502a93eca41dde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 7 Mar 2022 12:08:22 +0100 Subject: [PATCH 213/214] Bump up to v0.6.0 --- TTS/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/VERSION b/TTS/VERSION index 79a2734b..09a3acfa 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.5.0 \ No newline at end of file +0.6.0 \ No newline at end of file From d87985cde172df13b30ba861db791a800e6e871c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 7 Mar 2022 12:27:13 +0100 Subject: [PATCH 214/214] Update docs --- docs/source/formatting_your_dataset.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/formatting_your_dataset.md b/docs/source/formatting_your_dataset.md index 5b1d9801..294d2b29 100644 --- a/docs/source/formatting_your_dataset.md +++ b/docs/source/formatting_your_dataset.md @@ -19,15 +19,15 @@ Let's assume you created the audio clips and their transcription. You can collec You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each line must be delimitered by a special character separating the audio file name from the transcription. And make sure that the delimiter is not used in the transcription text. -We recommend the following format delimited by `||`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc. +We recommend the following format delimited by `|`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc. ``` # metadata.txt -audio1||This is my sentence. -audio2||This is maybe my sentence. -audio3||This is certainly my sentence. -audio4||Let this be your sentence. +audio1|This is my sentence. +audio2|This is maybe my sentence. +audio3|This is certainly my sentence. +audio4|Let this be your sentence. ... ```