mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #2187 from coqui-ai/dev-fix-vc
This commit is contained in:
commit
24620743ca
|
@@ -1211,8 +1211,8 @@ class Vits(BaseTTS):
|
|||
assert self.num_speakers > 0, "num_speakers have to be larger than 0."
|
||||
# speaker embedding
|
||||
if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
|
||||
- g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
|
||||
- g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
|
||||
+ g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1)
|
||||
+ g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1)
|
||||
elif not self.args.use_speaker_embedding and self.args.use_d_vector_file:
|
||||
g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
|
||||
g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
|
||||
|
|
|
@@ -134,8 +134,8 @@ class TestVits(unittest.TestCase):
|
|||
|
||||
ref_inp = torch.randn(1, 513, spec_len)
|
||||
ref_inp_len = torch.randint(1, spec_effective_len, (1,))
|
||||
- ref_spk_id = torch.randint(1, num_speakers, (1,))
|
||||
- tgt_spk_id = torch.randint(1, num_speakers, (1,))
|
||||
+ ref_spk_id = torch.randint(1, num_speakers, (1,)).item()
|
||||
+ tgt_spk_id = torch.randint(1, num_speakers, (1,)).item()
|
||||
o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)
|
||||
|
||||
self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))
|
||||
|
|
Loading…
Reference in New Issue