From ee20e309583d5c39a99b58c982127ea1f7256de9 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 5 Dec 2022 09:15:01 -0300
Subject: [PATCH 1/2] Fix VITS multi-speaker voice conversion inference

---
 TTS/tts/models/vits.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 9dfdc067..518809b3 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1211,8 +1211,8 @@ class Vits(BaseTTS):
         assert self.num_speakers > 0, "num_speakers have to be larger than 0."
         # speaker embedding
         if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
-            g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
-            g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
+            g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1)
+            g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1)
         elif not self.args.use_speaker_embedding and self.args.use_d_vector_file:
             g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
             g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)

From d2460de94b98c344f21a21ecc2f88c9385bea6f4 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 5 Dec 2022 09:59:11 -0300
Subject: [PATCH 2/2] Fix unit tests

---
 tests/tts_tests/test_vits.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index 7d474c20..ccc3be1c 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -134,8 +134,8 @@ class TestVits(unittest.TestCase):
         ref_inp = torch.randn(1, 513, spec_len)
         ref_inp_len = torch.randint(1, spec_effective_len, (1,))
-        ref_spk_id = torch.randint(1, num_speakers, (1,))
-        tgt_spk_id = torch.randint(1, num_speakers, (1,))
+        ref_spk_id = torch.randint(1, num_speakers, (1,)).item()
+        tgt_spk_id = torch.randint(1, num_speakers, (1,)).item()
         o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)
         self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))
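
Note: for reference, below is a minimal sketch of how the patched voice-conversion path is exercised, mirroring the updated unit test in PATCH 2/2. The VitsConfig/VitsArgs model construction and the concrete shapes are assumptions for illustration only; only the voice_conversion call signature and the use of plain integer speaker IDs come from the patch itself.

# Sketch: call Vits.voice_conversion with integer speaker IDs, as the fixed test does.
# The model setup below is an assumed minimal configuration, not part of the patch.
import torch

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs

num_speakers = 10
spec_len = 101

# Multi-speaker model with learned speaker embeddings (use_speaker_embedding=True),
# which is the branch the patched code path in vits.py handles.
args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True, spec_segment_size=10)
config = VitsConfig(model_args=args)
model = Vits(config)

ref_inp = torch.randn(1, 513, spec_len)          # reference linear spectrogram (batch, freq bins, frames)
ref_inp_len = torch.randint(1, spec_len, (1,))   # assumed valid reference length
ref_spk_id = torch.randint(1, num_speakers, (1,)).item()  # plain Python int, per the fix
tgt_spk_id = torch.randint(1, num_speakers, (1,)).item()

o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)
print(o_hat.shape)  # expected: (1, 1, spec_len * 256)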