mirror of https://github.com/coqui-ai/TTS.git
Fix emotion unit test
This commit is contained in:
parent
e8c4417f07
commit
836c4c6801
|
@@ -72,7 +72,7 @@ class TestVits(unittest.TestCase):
|
|||
|
||||
args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
|
||||
model = Vits(args)
|
||||
self.assertEqual(model.embedded_speaker_dim, 101)
|
||||
self.assertEqual(model.cond_embedding_dim, 101)
|
||||
|
||||
def test_init_multilingual(self):
|
||||
args = VitsArgs(language_ids_file=None, use_language_embedding=False)
|
||||
|
@@ -163,11 +163,11 @@ class TestVits(unittest.TestCase):
|
|||
output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
|
||||
)
|
||||
if encoder_config:
|
||||
self.assertEqual(output_dict["gt_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
|
||||
self.assertEqual(output_dict["syn_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
|
||||
self.assertEqual(output_dict["gt_cons_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
|
||||
self.assertEqual(output_dict["syn_cons_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
|
||||
else:
|
||||
self.assertEqual(output_dict["gt_spk_emb"], None)
|
||||
self.assertEqual(output_dict["syn_spk_emb"], None)
|
||||
self.assertEqual(output_dict["gt_cons_emb"], None)
|
||||
self.assertEqual(output_dict["syn_cons_emb"], None)
|
||||
|
||||
def test_forward(self):
|
||||
num_speakers = 0
|
||||
|
@@ -573,4 +573,4 @@ class TestVits(unittest.TestCase):
|
|||
model = Vits.init_from_config(config, verbose=False).to(device)
|
||||
self.assertTrue(model.num_speakers == 1)
|
||||
self.assertTrue(not hasattr(model, "emb_g"))
|
||||
self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)
|
||||
self.assertTrue(model.cond_embedding_dim == config.d_vector_dim)
|
||||
|
|
|
@@ -76,8 +76,8 @@ continue_restore_path, _ = get_last_checkpoint(continue_path)
|
|||
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
speaker_id = "ljspeech-1"
|
||||
emotion_id = "ljspeech-3"
|
||||
continue_speakers_path = os.path.join(continue_path, "speakers.json")
|
||||
continue_emotion_path = os.path.join(continue_path, "speakers.json")
|
||||
continue_speakers_path = config.model_args.d_vector_file
|
||||
continue_emotion_path = config.model_args.external_emotions_embs_file
|
||||
|
||||
|
||||
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
|
||||
|
|
Loading…
Reference in New Issue