Fix emotion unit test

This commit is contained in:
Edresson Casanova 2022-03-31 08:34:08 -03:00
parent 047cebd7b8
commit b692c77e6a
3 changed files with 8 additions and 9 deletions

View File

@@ -134,7 +134,6 @@ class EmbeddingManager(BaseIDManager):
file_path (str): Path to the target json file.
"""
self.embeddings = self._load_json(file_path)
speakers = sorted({x["name"] for x in self.embeddings.values()})
self.ids = {name: i for i, name in enumerate(speakers)}

View File

@@ -72,7 +72,7 @@ class TestVits(unittest.TestCase):
args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
model = Vits(args)
self.assertEqual(model.embedded_speaker_dim, 101)
self.assertEqual(model.cond_embedding_dim, 101)
def test_init_multilingual(self):
args = VitsArgs(language_ids_file=None, use_language_embedding=False)
@@ -163,11 +163,11 @@ class TestVits(unittest.TestCase):
output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
)
if encoder_config:
self.assertEqual(output_dict["gt_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
self.assertEqual(output_dict["syn_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
self.assertEqual(output_dict["gt_cons_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
self.assertEqual(output_dict["syn_cons_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
else:
self.assertEqual(output_dict["gt_spk_emb"], None)
self.assertEqual(output_dict["syn_spk_emb"], None)
self.assertEqual(output_dict["gt_cons_emb"], None)
self.assertEqual(output_dict["syn_cons_emb"], None)
def test_forward(self):
num_speakers = 0
@@ -503,4 +503,4 @@ class TestVits(unittest.TestCase):
model = Vits.init_from_config(config, verbose=False).to(device)
self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g"))
self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)
self.assertTrue(model.cond_embedding_dim == config.d_vector_dim)

View File

@@ -76,8 +76,8 @@ continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
emotion_id = "ljspeech-3"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
continue_emotion_path = os.path.join(continue_path, "speakers.json")
continue_speakers_path = config.model_args.d_vector_file
continue_emotion_path = config.model_args.external_emotions_embs_file
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"