mirror of https://github.com/coqui-ai/TTS.git
Fix unit tests
This commit is contained in:
parent 4b59f07946
commit 0844d9225d
@@ -972,7 +972,7 @@ class Vits(BaseTTS):
         self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
 
     def get_aux_input(self, aux_input: Dict):
-        sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg, pf, ssid, ssg = self._set_cond_input(aux_input)
         return {
             "speaker_ids": sid,
             "style_wav": None,
@@ -981,6 +981,8 @@ class Vits(BaseTTS):
             "emotion_embeddings": eg,
             "emotion_ids": eid,
             "style_feature": pf,
+            "style_speaker_ids": ssid,
+            "style_speaker_d_vectors": ssg,
         }
 
     def _freeze_layers(self):
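Read together, the two hunks above widen _set_cond_input to return eight values and surface the two new ones (ssid, ssg) through get_aux_input as style_speaker_ids and style_speaker_d_vectors. Note the asymmetry in the unit test below: the input keys are singular (style_speaker_id, style_speaker_d_vector) while the output keys are plural. A minimal sketch of exercising the new contract, assuming a constructed Vits instance named model and borrowing the tensor shapes from the test:

    import torch

    # Sketch only: `model` is assumed to be a constructed Vits instance.
    aux_input = {
        "speaker_ids": torch.randint(10, (1,)),
        "emotion_ids": torch.randint(10, (1,)),
        "style_feature": torch.rand(10, 128, 64),
        "style_speaker_id": torch.randint(10, (1,)),   # new input, singular key
        "style_speaker_d_vector": torch.rand(1, 128),  # new input, singular key
    }
    aux_out = model.get_aux_input(aux_input)
    # The output dict now also carries the plural-keyed style-speaker entries.
    assert "style_speaker_ids" in aux_out and "style_speaker_d_vectors" in aux_out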
@@ -111,6 +111,8 @@ class TestVits(unittest.TestCase):
         emotion_embedding = torch.rand(1, 128)
         d_vector = torch.rand(1, 128)
         style_feature = torch.rand(10, 128, 64)
+        style_speaker_id = torch.randint(10, (1,))
+        style_speaker_d_vector = torch.rand(1, 128)
 
         aux_input = {
             "speaker_ids": speaker_id,
@@ -120,6 +122,8 @@ class TestVits(unittest.TestCase):
             "style_feature": style_feature,
             "emotion_ids": emotion_id,
             "emotion_embeddings": emotion_embedding,
+            "style_speaker_id": style_speaker_id,
+            "style_speaker_d_vector": style_speaker_d_vector,
         }
         aux_out = model.get_aux_input(aux_input)
         self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
@@ -128,6 +132,10 @@ class TestVits(unittest.TestCase):
         self.assertEqual(aux_out["emotion_ids"].shape, emotion_id.shape)
         self.assertEqual(aux_out["emotion_embeddings"].shape, emotion_embedding.unsqueeze(0).transpose(2, 1).shape)
         self.assertEqual(aux_out["style_feature"].shape, style_feature.shape)
+        self.assertEqual(aux_out["style_speaker_ids"].shape, style_speaker_id.shape)
+        self.assertEqual(
+            aux_out["style_speaker_d_vectors"].shape, style_speaker_d_vector.unsqueeze(0).transpose(2, 1).shape
+        )
 
     def test_voice_conversion(self):
         num_speakers = 10
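The embedding assertions above compare against unsqueeze(0).transpose(2, 1) because get_aux_input returns d-vectors and emotion embeddings channels-first with a trailing length-one axis, presumably so they broadcast against [batch, channels, time] feature maps. A standalone check of that reshape, independent of the model:

    import torch

    # A (1, 128) embedding becomes (1, 1, 128) after unsqueeze(0),
    # then (1, 128, 1) after transpose(2, 1), i.e. [batch, channels, 1].
    emb = torch.rand(1, 128)
    reshaped = emb.unsqueeze(0).transpose(2, 1)
    assert reshaped.shape == (1, 128, 1)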
@@ -38,16 +38,16 @@ config.model_args.use_speaker_embedding = False
 config.model_args.use_d_vector_file = True
 config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
 config.model_args.speaker_embedding_channels = 128
-config.model_args.d_vector_dim = 256
+config.model_args.d_vector_dim = 2
 
 # emotion
-config.model_args.emotion_embedding_dim = 256
+config.model_args.emotion_embedding_dim = 2
 
-config.model_args.use_emotion_embedding_squeezer = False
+config.model_args.use_emotion_embedding_squeezer = True
 config.model_args.emotion_embedding_squeezer_input_dim = 256
 config.model_args.use_speaker_embedding_as_emotion = True
 
-config.model_args.use_speaker_embedding_squeezer = False
+config.model_args.use_speaker_embedding_squeezer = True
 config.model_args.speaker_embedding_squeezer_input_dim = 256
 
 # consistency loss
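One plausible reading of these config edits (the squeezer modules themselves are outside this diff): the embedding dims drop from 256 to 2 so the test model stays small and fast, while the squeezers are switched on so that the stored 256-dim vectors (hence the unchanged *_squeezer_input_dim = 256) can be projected down to the new 2-dim conditioning size. As an illustration only, such a squeezer could be as simple as a linear projection; the real module may differ:

    import torch
    from torch import nn

    # Hypothetical sketch of an embedding "squeezer": project stored 256-dim
    # d-vectors down to the 2-dim size configured above. Not the actual module.
    squeezer = nn.Linear(256, 2)   # squeezer_input_dim -> d_vector_dim
    d_vector = torch.rand(1, 256)  # e.g. a d-vector loaded from speakers.json
    cond = squeezer(d_vector)
    assert cond.shape == (1, 2)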