Fix unit tests

This commit is contained in:
Edresson Casanova 2022-06-08 10:18:19 -03:00
parent 4b59f07946
commit 0844d9225d
3 changed files with 15 additions and 5 deletions

View File

@ -972,7 +972,7 @@ class Vits(BaseTTS):
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
def get_aux_input(self, aux_input: Dict):
sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
sid, g, lid, eid, eg, pf, ssid, ssg = self._set_cond_input(aux_input)
return {
"speaker_ids": sid,
"style_wav": None,
@ -981,6 +981,8 @@ class Vits(BaseTTS):
"emotion_embeddings": eg,
"emotion_ids": eid,
"style_feature": pf,
"style_speaker_ids": ssid,
"style_speaker_d_vectors": ssg,
}
def _freeze_layers(self):

View File

@ -111,6 +111,8 @@ class TestVits(unittest.TestCase):
emotion_embedding = torch.rand(1, 128)
d_vector = torch.rand(1, 128)
style_feature = torch.rand(10, 128, 64)
style_speaker_id = torch.randint(10, (1,))
style_speaker_d_vector = torch.rand(1, 128)
aux_input = {
"speaker_ids": speaker_id,
@ -120,6 +122,8 @@ class TestVits(unittest.TestCase):
"style_feature": style_feature,
"emotion_ids": emotion_id,
"emotion_embeddings": emotion_embedding,
"style_speaker_id": style_speaker_id,
"style_speaker_d_vector": style_speaker_d_vector,
}
aux_out = model.get_aux_input(aux_input)
self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
@ -128,6 +132,10 @@ class TestVits(unittest.TestCase):
self.assertEqual(aux_out["emotion_ids"].shape, emotion_id.shape)
self.assertEqual(aux_out["emotion_embeddings"].shape, emotion_embedding.unsqueeze(0).transpose(2, 1).shape)
self.assertEqual(aux_out["style_feature"].shape, style_feature.shape)
self.assertEqual(aux_out["style_speaker_ids"].shape, style_speaker_id.shape)
self.assertEqual(
aux_out["style_speaker_d_vectors"].shape, style_speaker_d_vector.unsqueeze(0).transpose(2, 1).shape
)
def test_voice_conversion(self):
num_speakers = 10

View File

@ -38,16 +38,16 @@ config.model_args.use_speaker_embedding = False
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 256
config.model_args.d_vector_dim = 2
# emotion
config.model_args.emotion_embedding_dim = 256
config.model_args.emotion_embedding_dim = 2
config.model_args.use_emotion_embedding_squeezer = False
config.model_args.use_emotion_embedding_squeezer = True
config.model_args.emotion_embedding_squeezer_input_dim = 256
config.model_args.use_speaker_embedding_as_emotion = True
config.model_args.use_speaker_embedding_squeezer = False
config.model_args.use_speaker_embedding_squeezer = True
config.model_args.speaker_embedding_squeezer_input_dim = 256
# consistency loss