diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py
index 2b1faf4d..8db0dbc3 100644
--- a/TTS/tts/models/tacotron2.py
+++ b/TTS/tts/models/tacotron2.py
@@ -259,15 +259,15 @@ class Tacotron2(BaseTacotron):
                     encoder_outputs.device
                 )  # pylint: disable=not-callable
             reference_mel_length = (
-                torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
-                if aux_input["style_mel"] is not None
+                torch.tensor([aux_input["style_feature"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
+                if aux_input["style_feature"] is not None
                 else None
             )  # pylint: disable=not-callable
             # B x capacitron_VAE_embedding_dim
             encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
                 encoder_outputs,
-                reference_mel_info=[aux_input["style_mel"], reference_mel_length]
-                if aux_input["style_mel"] is not None
+                reference_mel_info=[aux_input["style_feature"], reference_mel_length]
+                if aux_input["style_feature"] is not None
                 else None,
                 text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
                 speaker_embedding=aux_input["d_vectors"]
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index cbe310a9..4d385a2a 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -884,7 +884,7 @@ class Vits(BaseTTS):
             self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
 
     def get_aux_input(self, aux_input: Dict):
-        sid, g, lid, eid, eg = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
         return {
             "speaker_ids": sid,
             "style_wav": None,
@@ -892,6 +892,7 @@
             "language_ids": lid,
             "emotion_embeddings": eg,
             "emotion_ids": eid,
+            "style_feature": pf,
         }
 
     def _freeze_layers(self):
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index fc285007..e0a753ea 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -107,12 +107,27 @@ class TestVits(unittest.TestCase):
         speaker_id = torch.randint(10, (1,))
         language_id = torch.randint(10, (1,))
+        emotion_id = torch.randint(10, (1,))
+        emotion_embedding = torch.rand(1, 128)
         d_vector = torch.rand(1, 128)
-        aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
+        style_feature = torch.rand(10, 128, 64)
+
+        aux_input = {
+            "speaker_ids": speaker_id,
+            "style_wav": None,
+            "d_vectors": d_vector,
+            "language_ids": language_id,
+            "style_feature": style_feature,
+            "emotion_ids": emotion_id,
+            "emotion_embeddings": emotion_embedding,
+        }
         aux_out = model.get_aux_input(aux_input)
         self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
         self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
         self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape)
+        self.assertEqual(aux_out["emotion_ids"].shape, emotion_id.shape)
+        self.assertEqual(aux_out["emotion_embeddings"].shape, emotion_embedding.unsqueeze(0).transpose(2, 1).shape)
+        self.assertEqual(aux_out["style_feature"].shape, style_feature.shape)
 
     def test_voice_conversion(self):
         num_speakers = 10
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
index 4856c364..4b67a339 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
@@ -76,7 +76,7 @@ continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
 emotion_id = "ljspeech-3"
-continue_speakers_path = config.d_vector_file
+continue_speakers_path = config.model_args.d_vector_file
 continue_emotion_path = os.path.join(continue_path, "emotions.json")
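Note (not part of the diff): a minimal sketch of how the renamed "style_feature" key is expected to flow through Vits.get_aux_input after this change. The tensor shapes mirror the updated test and are illustrative only; "model" stands for a Vits instance built as in tests/tts_tests/test_vits.py, and _set_cond_input is assumed to return the style feature as its sixth value.

    import torch

    # Illustrative input mirroring test_get_aux_input;
    # "style_feature" replaces the old "style_mel" key.
    aux_input = {
        "speaker_ids": torch.randint(10, (1,)),
        "style_wav": None,
        "d_vectors": torch.rand(1, 128),
        "language_ids": torch.randint(10, (1,)),
        "style_feature": torch.rand(10, 128, 64),
        "emotion_ids": torch.randint(10, (1,)),
        "emotion_embeddings": torch.rand(1, 128),
    }

    # get_aux_input unpacks six values from _set_cond_input and passes the
    # style feature through unchanged under the "style_feature" key.
    aux_out = model.get_aux_input(aux_input)
    assert aux_out["style_feature"].shape == aux_input["style_feature"].shape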