From 836c4c6801c815d90c81e35120566f23636a9d91 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Thu, 31 Mar 2022 08:34:08 -0300
Subject: [PATCH] Fix emotion unit test

---
 tests/tts_tests/test_vits.py                          | 12 ++++++------
 ...test_vits_d_vector_with_external_emotion_train.py  |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index b9cebb5a..8f7ed888 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -72,7 +72,7 @@ class TestVits(unittest.TestCase):
 
         args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
         model = Vits(args)
-        self.assertEqual(model.embedded_speaker_dim, 101)
+        self.assertEqual(model.cond_embedding_dim, 101)
 
     def test_init_multilingual(self):
         args = VitsArgs(language_ids_file=None, use_language_embedding=False)
@@ -163,11 +163,11 @@ class TestVits(unittest.TestCase):
             output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
         )
         if encoder_config:
-            self.assertEqual(output_dict["gt_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
-            self.assertEqual(output_dict["syn_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
+            self.assertEqual(output_dict["gt_cons_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
+            self.assertEqual(output_dict["syn_cons_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
         else:
-            self.assertEqual(output_dict["gt_spk_emb"], None)
-            self.assertEqual(output_dict["syn_spk_emb"], None)
+            self.assertEqual(output_dict["gt_cons_emb"], None)
+            self.assertEqual(output_dict["syn_cons_emb"], None)
 
     def test_forward(self):
         num_speakers = 0
@@ -573,4 +573,4 @@ class TestVits(unittest.TestCase):
         model = Vits.init_from_config(config, verbose=False).to(device)
         self.assertTrue(model.num_speakers == 1)
         self.assertTrue(not hasattr(model, "emb_g"))
-        self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)
+        self.assertTrue(model.cond_embedding_dim == config.d_vector_dim)
diff --git a/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py b/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py
index 75fba5fc..55f0492d 100644
--- a/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py
+++ b/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py
@@ -76,8 +76,8 @@
 continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
 emotion_id = "ljspeech-3"
-continue_speakers_path = os.path.join(continue_path, "speakers.json")
-continue_emotion_path = os.path.join(continue_path, "speakers.json")
+continue_speakers_path = config.model_args.d_vector_file
+continue_emotion_path = config.model_args.external_emotions_embs_file
 
 inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"