mirror of https://github.com/coqui-ai/TTS.git
Bug fix in freeze encoder (#1391)
* Fix the bug in freeze encoder * Remove emb_l definition for non-multilingual training * Fix unit tests
This commit is contained in:
parent
464dc658ff
commit
37896e1743
|
@ -706,7 +706,6 @@ class Vits(BaseTTS):
|
|||
torch.nn.init.xavier_uniform_(self.emb_l.weight)
|
||||
else:
|
||||
self.embedded_language_dim = 0
|
||||
self.emb_l = None
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
|
|
|
@ -79,25 +79,25 @@ class TestVits(unittest.TestCase):
|
|||
model = Vits(args)
|
||||
self.assertEqual(model.language_manager, None)
|
||||
self.assertEqual(model.embedded_language_dim, 0)
|
||||
self.assertEqual(model.emb_l, None)
|
||||
assertHasNotAttr(self, model, "emb_l")
|
||||
|
||||
args = VitsArgs(language_ids_file=LANG_FILE)
|
||||
model = Vits(args)
|
||||
self.assertNotEqual(model.language_manager, None)
|
||||
self.assertEqual(model.embedded_language_dim, 0)
|
||||
self.assertEqual(model.emb_l, None)
|
||||
assertHasNotAttr(self, model, "emb_l")
|
||||
|
||||
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
|
||||
model = Vits(args)
|
||||
self.assertNotEqual(model.language_manager, None)
|
||||
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
|
||||
self.assertNotEqual(model.emb_l, None)
|
||||
assertHasAttr(self, model, "emb_l")
|
||||
|
||||
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
|
||||
model = Vits(args)
|
||||
self.assertNotEqual(model.language_manager, None)
|
||||
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
|
||||
self.assertNotEqual(model.emb_l, None)
|
||||
assertHasAttr(self, model, "emb_l")
|
||||
|
||||
def test_get_aux_input(self):
|
||||
aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
|
||||
|
|
Loading…
Reference in New Issue