diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 5be47885..3e33e271 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -972,7 +972,7 @@ class Vits(BaseTTS):
         self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
 
     def get_aux_input(self, aux_input: Dict):
-        sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg, pf, ssid, ssg = self._set_cond_input(aux_input)
         return {
             "speaker_ids": sid,
             "style_wav": None,
@@ -981,6 +981,8 @@ class Vits(BaseTTS):
             "emotion_embeddings": eg,
             "emotion_ids": eid,
             "style_feature": pf,
+            "style_speaker_ids": ssid,
+            "style_speaker_d_vectors": ssg,
         }
 
     def _freeze_layers(self):
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index e0a753ea..0312e0ec 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -111,6 +111,8 @@ class TestVits(unittest.TestCase):
         emotion_embedding = torch.rand(1, 128)
         d_vector = torch.rand(1, 128)
         style_feature = torch.rand(10, 128, 64)
+        style_speaker_id = torch.randint(10, (1,))
+        style_speaker_d_vector = torch.rand(1, 128)
 
         aux_input = {
             "speaker_ids": speaker_id,
@@ -120,6 +122,8 @@
             "style_feature": style_feature,
             "emotion_ids": emotion_id,
             "emotion_embeddings": emotion_embedding,
+            "style_speaker_id": style_speaker_id,
+            "style_speaker_d_vector": style_speaker_d_vector,
         }
         aux_out = model.get_aux_input(aux_input)
         self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
@@ -128,6 +132,10 @@
         self.assertEqual(aux_out["emotion_ids"].shape, emotion_id.shape)
         self.assertEqual(aux_out["emotion_embeddings"].shape, emotion_embedding.unsqueeze(0).transpose(2, 1).shape)
         self.assertEqual(aux_out["style_feature"].shape, style_feature.shape)
+        self.assertEqual(aux_out["style_speaker_ids"].shape, style_speaker_id.shape)
+        self.assertEqual(
+            aux_out["style_speaker_d_vectors"].shape, style_speaker_d_vector.unsqueeze(0).transpose(2, 1).shape
+        )
 
     def test_voice_conversion(self):
         num_speakers = 10
diff --git a/tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py b/tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py
index ac157c58..29045cac 100644
--- a/tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py
+++ b/tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py
@@ -38,16 +38,16 @@ config.model_args.use_speaker_embedding = False
 config.model_args.use_d_vector_file = True
 config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
 config.model_args.speaker_embedding_channels = 128
-config.model_args.d_vector_dim = 256
+config.model_args.d_vector_dim = 2
 
 # emotion
-config.model_args.emotion_embedding_dim = 256
+config.model_args.emotion_embedding_dim = 2
 
-config.model_args.use_emotion_embedding_squeezer = False
+config.model_args.use_emotion_embedding_squeezer = True
 config.model_args.emotion_embedding_squeezer_input_dim = 256
 config.model_args.use_speaker_embedding_as_emotion = True
-config.model_args.use_speaker_embedding_squeezer = False
+config.model_args.use_speaker_embedding_squeezer = True
 config.model_args.speaker_embedding_squeezer_input_dim = 256
 
 # consistency loss
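
Note (not part of the diff): a minimal sketch of how the new style-speaker keys flow through `get_aux_input`, mirroring the updated `TestVits` assertions above. It assumes the emotion/style fork of Coqui TTS that this diff targets; the `Vits(VitsConfig())` construction and the tensor dimensions are assumptions taken from the test code, while the key names and output shapes come from the diff itself.

    import torch

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits

    # Assumed construction; the test file builds its model elsewhere.
    model = Vits(VitsConfig())

    aux_input = {
        "speaker_ids": torch.randint(10, (1,)),
        "d_vectors": torch.rand(1, 128),
        "emotion_ids": torch.randint(10, (1,)),
        "emotion_embeddings": torch.rand(1, 128),
        "style_feature": torch.rand(10, 128, 64),
        # New inputs added by this diff (note the singular key names on input):
        "style_speaker_id": torch.randint(10, (1,)),
        "style_speaker_d_vector": torch.rand(1, 128),
    }

    aux_out = model.get_aux_input(aux_input)
    # _set_cond_input reshapes 2-D embeddings to (1, C, 1), which is why the
    # test compares against unsqueeze(0).transpose(2, 1).
    print(aux_out["style_speaker_ids"].shape)        # torch.Size([1])
    print(aux_out["style_speaker_d_vectors"].shape)  # torch.Size([1, 128, 1])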