Merge pull request #2187 from coqui-ai/dev-fix-vc

This commit is contained in:
Eren Gölge 2022-12-06 21:27:34 +01:00 committed by GitHub
commit 24620743ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 4 deletions

View File

@ -1211,8 +1211,8 @@ class Vits(BaseTTS):
assert self.num_speakers > 0, "num_speakers have to be larger than 0."
# speaker embedding
if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1)
g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1)
elif not self.args.use_speaker_embedding and self.args.use_d_vector_file:
g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)

View File

@ -134,8 +134,8 @@ class TestVits(unittest.TestCase):
ref_inp = torch.randn(1, 513, spec_len)
ref_inp_len = torch.randint(1, spec_effective_len, (1,))
ref_spk_id = torch.randint(1, num_speakers, (1,))
tgt_spk_id = torch.randint(1, num_speakers, (1,))
ref_spk_id = torch.randint(1, num_speakers, (1,)).item()
tgt_spk_id = torch.randint(1, num_speakers, (1,)).item()
o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)
self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))