Fix unit tests

This commit is contained in:
Edresson Casanova 2022-05-23 10:26:45 -03:00
parent 749b217884
commit 6e4b13c6cc
4 changed files with 23 additions and 7 deletions

View File

@@ -259,15 +259,15 @@ class Tacotron2(BaseTacotron):
encoder_outputs.device
) # pylint: disable=not-callable
reference_mel_length = (
torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
if aux_input["style_mel"] is not None
torch.tensor([aux_input["style_feature"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
if aux_input["style_feature"] is not None
else None
) # pylint: disable=not-callable
# B x capacitron_VAE_embedding_dim
encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
encoder_outputs,
reference_mel_info=[aux_input["style_mel"], reference_mel_length]
if aux_input["style_mel"] is not None
reference_mel_info=[aux_input["style_feature"], reference_mel_length]
if aux_input["style_feature"] is not None
else None,
text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
speaker_embedding=aux_input["d_vectors"]

View File

@@ -884,7 +884,7 @@ class Vits(BaseTTS):
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
def get_aux_input(self, aux_input: Dict):
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
return {
"speaker_ids": sid,
"style_wav": None,
@@ -892,6 +892,7 @@ class Vits(BaseTTS):
"language_ids": lid,
"emotion_embeddings": eg,
"emotion_ids": eid,
"style_feature": pf,
}
def _freeze_layers(self):

View File

@@ -107,12 +107,27 @@ class TestVits(unittest.TestCase):
speaker_id = torch.randint(10, (1,))
language_id = torch.randint(10, (1,))
emotion_id = torch.randint(10, (1,))
emotion_embedding = torch.rand(1, 128)
d_vector = torch.rand(1, 128)
aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
style_feature = torch.rand(10, 128, 64)
aux_input = {
"speaker_ids": speaker_id,
"style_wav": None,
"d_vectors": d_vector,
"language_ids": language_id,
"style_feature": style_feature,
"emotion_ids": emotion_id,
"emotion_embeddings": emotion_embedding,
}
aux_out = model.get_aux_input(aux_input)
self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape)
self.assertEqual(aux_out["emotion_ids"].shape, emotion_id.shape)
self.assertEqual(aux_out["emotion_embeddings"].shape, emotion_embedding.unsqueeze(0).transpose(2, 1).shape)
self.assertEqual(aux_out["style_feature"].shape, style_feature.shape)
def test_voice_conversion(self):
num_speakers = 10

View File

@@ -76,7 +76,7 @@ continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
emotion_id = "ljspeech-3"
continue_speakers_path = config.d_vector_file
continue_speakers_path = config.model_args.d_vector_file
continue_emotion_path = os.path.join(continue_path, "emotions.json")