Bug fix in freeze encoder (#1391)

* Fix the bug in freeze encoder

* Remove emb_l definition for non-multilingual training

* Fix unit tests
This commit is contained in:
Edresson Casanova 2022-03-24 14:16:04 -03:00 committed by GitHub
parent 464dc658ff
commit 37896e1743
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 5 deletions

View File

@ -706,7 +706,6 @@ class Vits(BaseTTS):
torch.nn.init.xavier_uniform_(self.emb_l.weight)
else:
self.embedded_language_dim = 0
self.emb_l = None
def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)

View File

@ -79,25 +79,25 @@ class TestVits(unittest.TestCase):
model = Vits(args)
self.assertEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, 0)
self.assertEqual(model.emb_l, None)
assertHasNotAttr(self, model, "emb_l")
args = VitsArgs(language_ids_file=LANG_FILE)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, 0)
self.assertEqual(model.emb_l, None)
assertHasNotAttr(self, model, "emb_l")
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
self.assertNotEqual(model.emb_l, None)
assertHasAttr(self, model, "emb_l")
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
self.assertNotEqual(model.emb_l, None)
assertHasAttr(self, model, "emb_l")
def test_get_aux_input(self):
aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}