mirror of https://github.com/coqui-ai/TTS.git
Fix unit tests
parent 749b217884
commit 6e4b13c6cc
@@ -259,15 +259,15 @@ class Tacotron2(BaseTacotron):
                 encoder_outputs.device
             )  # pylint: disable=not-callable
             reference_mel_length = (
-                torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
-                if aux_input["style_mel"] is not None
+                torch.tensor([aux_input["style_feature"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
+                if aux_input["style_feature"] is not None
                 else None
             )  # pylint: disable=not-callable
             # B x capacitron_VAE_embedding_dim
             encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
                 encoder_outputs,
-                reference_mel_info=[aux_input["style_mel"], reference_mel_length]
-                if aux_input["style_mel"] is not None
+                reference_mel_info=[aux_input["style_feature"], reference_mel_length]
+                if aux_input["style_feature"] is not None
                 else None,
                 text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
                 speaker_embedding=aux_input["d_vectors"]
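Note: the `style_mel` → `style_feature` rename is mechanical, but the conditional it touches is worth spelling out: the reference length is taken from dim 1 of the style tensor only when one is supplied, and `None` is propagated otherwise. A standalone sketch of that pattern (the shapes and device here are illustrative, not the model's):

```python
import torch

# Hypothetical aux_input carrying a [B, T, C] style feature, as in the hunk above.
aux_input = {"style_feature": torch.rand(1, 120, 80)}

# Length is the time dimension (size(1)), wrapped as an int64 tensor on the
# target device; None propagates when no style feature is given.
device = torch.device("cpu")
reference_mel_length = (
    torch.tensor([aux_input["style_feature"].size(1)], dtype=torch.int64).to(device)
    if aux_input["style_feature"] is not None
    else None
)
print(reference_mel_length)  # tensor([120])
```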
@@ -884,7 +884,7 @@ class Vits(BaseTTS):
             self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
 
     def get_aux_input(self, aux_input: Dict):
-        sid, g, lid, eid, eg = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
         return {
             "speaker_ids": sid,
             "style_wav": None,
@@ -892,6 +892,7 @@ class Vits(BaseTTS):
             "language_ids": lid,
             "emotion_embeddings": eg,
             "emotion_ids": eid,
+            "style_feature": pf,
         }
 
     def _freeze_layers(self):
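Note: this is the change the failing tests were tripping over. `_set_cond_input` now returns six values (the new `pf` style feature), so any caller still unpacking five names raises a `ValueError`. A minimal stand-in (not the real `_set_cond_input`) reproducing that failure mode:

```python
# Stand-in with the new six-value arity; the real _set_cond_input
# derives these values from aux_input.
def set_cond_input_stub():
    sid, g, lid, eid, eg, pf = 1, 2, 3, 4, 5, 6
    return sid, g, lid, eid, eg, pf

try:
    sid, g, lid, eid, eg = set_cond_input_stub()  # old five-name unpacking
except ValueError as err:
    print(err)  # too many values to unpack (expected 5)

sid, g, lid, eid, eg, pf = set_cond_input_stub()  # fixed unpacking
```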
@@ -107,12 +107,27 @@ class TestVits(unittest.TestCase):
 
         speaker_id = torch.randint(10, (1,))
         language_id = torch.randint(10, (1,))
+        emotion_id = torch.randint(10, (1,))
+        emotion_embedding = torch.rand(1, 128)
         d_vector = torch.rand(1, 128)
-        aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
+        style_feature = torch.rand(10, 128, 64)
+
+        aux_input = {
+            "speaker_ids": speaker_id,
+            "style_wav": None,
+            "d_vectors": d_vector,
+            "language_ids": language_id,
+            "style_feature": style_feature,
+            "emotion_ids": emotion_id,
+            "emotion_embeddings": emotion_embedding,
+        }
         aux_out = model.get_aux_input(aux_input)
         self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
         self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
         self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape)
+        self.assertEqual(aux_out["emotion_ids"].shape, emotion_id.shape)
+        self.assertEqual(aux_out["emotion_embeddings"].shape, emotion_embedding.unsqueeze(0).transpose(2, 1).shape)
+        self.assertEqual(aux_out["style_feature"].shape, style_feature.shape)
 
     def test_voice_conversion(self):
         num_speakers = 10
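Note: the `unsqueeze(0).transpose(2, 1)` in the d-vector and emotion-embedding assertions mirrors how `get_aux_input` appears to reshape [B, C] conditioning vectors to [B, C, 1] for broadcasting over time. A quick standalone shape check:

```python
import torch

emotion_embedding = torch.rand(1, 128)  # [B, C], as created in the test
# unsqueeze(0) gives [1, 1, 128]; transpose(2, 1) then gives [1, 128, 1]
reshaped = emotion_embedding.unsqueeze(0).transpose(2, 1)
print(emotion_embedding.shape)  # torch.Size([1, 128])
print(reshaped.shape)           # torch.Size([1, 128, 1])
```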
@@ -76,7 +76,7 @@ continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
 emotion_id = "ljspeech-3"
-continue_speakers_path = config.d_vector_file
+continue_speakers_path = config.model_args.d_vector_file
 continue_emotion_path = os.path.join(continue_path, "emotions.json")
 
 
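Note: the one-line fix points the test at the attribute's nested location: in this config layout `d_vector_file` lives on `model_args`, not on the top-level config. A simplified sketch of that nesting (plain dataclasses standing in for the real config classes):

```python
from dataclasses import dataclass, field

@dataclass
class ModelArgs:  # stands in for the model-args config
    d_vector_file: str = "speakers.json"

@dataclass
class Config:  # stands in for the top-level training config
    model_args: ModelArgs = field(default_factory=ModelArgs)

config = Config()
# continue_speakers_path = config.d_vector_file          # AttributeError: no such field
continue_speakers_path = config.model_args.d_vector_file  # correct nested access
print(continue_speakers_path)
```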