From 7a0eba517fcfb4e349354f8b3dcff9916a2ee644 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Tue, 15 Mar 2022 13:09:58 +0000
Subject: [PATCH] Add emotion external embeddings training unit test

---
 TTS/tts/models/vits.py                        |  9 +-
 ...ts_d_vector_with_external_emotion_train.py | 89 +++++++++++++++++++
 2 files changed, 93 insertions(+), 5 deletions(-)
 create mode 100644 tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index e1d34395..df010de6 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -602,7 +602,7 @@ class Vits(BaseTTS):
         self.init_multispeaker(config)
         self.init_multilingual(config)
         self.init_upsampling()
-        self.init_emotion(config, emotion_manager)
+        self.init_emotion(emotion_manager)
         self.init_consistency_loss()
 
         self.length_scale = self.args.length_scale
@@ -822,7 +822,7 @@ class Vits(BaseTTS):
             raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
         print(" > Text Encoder was reinit.")
 
-    def init_emotion(self, config: Coqpit, emotion_manager: EmotionManager):
+    def init_emotion(self, emotion_manager: EmotionManager):
         # pylint: disable=attribute-defined-outside-init
         """Initialize emotion modules of a model. A model can be trained either with a emotion embedding layer
         or with external `embeddings` computed from a emotion encoder model.
@@ -830,7 +830,6 @@ class Vits(BaseTTS):
         You must provide a `emotion_manager` at initialization to set up the emotion modules.
 
         Args:
-            config (Coqpit): Model configuration.
             emotion_manager (Coqpit): Emotion Manager.
         """
         self.emotion_manager = emotion_manager
@@ -1031,7 +1030,7 @@ class Vits(BaseTTS):
 
         # concat the emotion embedding and speaker embedding
         if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-            g = torch.cat([g, eg], dim=1) # [b, h1+h1, 1]
+            g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
 
         # language embedding
         lang_emb = None
@@ -1146,7 +1145,7 @@ class Vits(BaseTTS):
             eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
 
         # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+        if eg is not None and g is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             g = torch.cat([g, eg], dim=1) # [b, h1+h1, 1]
 
         # language embedding
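[Reviewer note] The two `torch.cat` hunks above condition the decoder on both speaker and emotion: a speaker embedding `g` of dim h1 and an emotion embedding `eg` of dim h2 are concatenated along the channel axis, hence the corrected shape comment `[b, h1+h2, 1]`. A minimal standalone sketch of that concat, not Coqui code; the batch size and dims below are illustrative:

import torch

b, h1, h2 = 2, 256, 256    # batch size, speaker dim, emotion dim (illustrative)
g = torch.randn(b, h1, 1)  # speaker d-vector, [b, h1, 1]
eg = torch.randn(b, h2, 1) # external emotion embedding, [b, h2, 1]

# mirrors the guarded concat the inference-path hunk adds
if eg is not None and g is not None:
    g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
assert g.shape == (b, h1 + h2, 1)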
diff --git a/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py b/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py
new file mode 100644
index 00000000..75fba5fc
--- /dev/null
+++ b/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py
@@ -0,0 +1,89 @@
+import glob
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    test_sentences=[
+        ["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
+    ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# activate multispeaker d-vec mode
+config.model_args.use_speaker_embedding = False
+config.use_speaker_embedding = False
+config.model_args.use_d_vector_file = True
+config.use_d_vector_file = True
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.d_vector_dim = 256
+
+# emotion
+config.model_args.use_external_emotions_embeddings = True
+config.model_args.use_emotion_embedding = False
+config.model_args.emotion_embedding_dim = 256
+config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
+
+# consistency loss
+# config.model_args.use_emotion_encoder_as_loss = True
+# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
+# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json"
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using the tts CLI
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+emotion_id = "ljspeech-3"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+continue_emotion_path = os.path.join(continue_path, "speakers.json")
+
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
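[Reviewer note] The test points both `d_vector_file` and `external_emotions_embs_file` at the same `tests/data/ljspeech/speakers.json` fixture, so the emotion "IDs" (`ljspeech-3`) are just entry names from that file. A rough sketch of producing such a fixture, assuming the usual d-vector JSON layout of clip name mapping to a name label plus an embedding list; the output path, labels, and 256-dim random vectors are illustrative placeholders, not the real fixture:

import json
import random

# Hypothetical fixture generator: one 256-dim placeholder embedding per clip,
# keyed by audio file name, with a "name" label that entries are grouped by.
embeddings = {
    f"LJ001-000{i}.wav": {
        "name": f"ljspeech-{i}",
        "embedding": [random.random() for _ in range(256)],
    }
    for i in range(1, 4)
}

with open("embeddings.json", "w", encoding="utf-8") as f:
    json.dump(embeddings, f, indent=2)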