mirror of https://github.com/coqui-ai/TTS.git
Bug fix in freeze encoder (#1391)
* Fix the bug in freeze encoder * Remove emb_l definition for non-multilingual training * Fix unit tests
This commit is contained in:
parent
464dc658ff
commit
37896e1743
|
@ -706,7 +706,6 @@ class Vits(BaseTTS):
|
||||||
torch.nn.init.xavier_uniform_(self.emb_l.weight)
|
torch.nn.init.xavier_uniform_(self.emb_l.weight)
|
||||||
else:
|
else:
|
||||||
self.embedded_language_dim = 0
|
self.embedded_language_dim = 0
|
||||||
self.emb_l = None
|
|
||||||
|
|
||||||
def get_aux_input(self, aux_input: Dict):
|
def get_aux_input(self, aux_input: Dict):
|
||||||
sid, g, lid = self._set_cond_input(aux_input)
|
sid, g, lid = self._set_cond_input(aux_input)
|
||||||
|
|
|
@ -79,25 +79,25 @@ class TestVits(unittest.TestCase):
|
||||||
model = Vits(args)
|
model = Vits(args)
|
||||||
self.assertEqual(model.language_manager, None)
|
self.assertEqual(model.language_manager, None)
|
||||||
self.assertEqual(model.embedded_language_dim, 0)
|
self.assertEqual(model.embedded_language_dim, 0)
|
||||||
self.assertEqual(model.emb_l, None)
|
assertHasNotAttr(self, model, "emb_l")
|
||||||
|
|
||||||
args = VitsArgs(language_ids_file=LANG_FILE)
|
args = VitsArgs(language_ids_file=LANG_FILE)
|
||||||
model = Vits(args)
|
model = Vits(args)
|
||||||
self.assertNotEqual(model.language_manager, None)
|
self.assertNotEqual(model.language_manager, None)
|
||||||
self.assertEqual(model.embedded_language_dim, 0)
|
self.assertEqual(model.embedded_language_dim, 0)
|
||||||
self.assertEqual(model.emb_l, None)
|
assertHasNotAttr(self, model, "emb_l")
|
||||||
|
|
||||||
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
|
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
|
||||||
model = Vits(args)
|
model = Vits(args)
|
||||||
self.assertNotEqual(model.language_manager, None)
|
self.assertNotEqual(model.language_manager, None)
|
||||||
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
|
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
|
||||||
self.assertNotEqual(model.emb_l, None)
|
assertHasAttr(self, model, "emb_l")
|
||||||
|
|
||||||
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
|
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
|
||||||
model = Vits(args)
|
model = Vits(args)
|
||||||
self.assertNotEqual(model.language_manager, None)
|
self.assertNotEqual(model.language_manager, None)
|
||||||
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
|
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
|
||||||
self.assertNotEqual(model.emb_l, None)
|
assertHasAttr(self, model, "emb_l")
|
||||||
|
|
||||||
def test_get_aux_input(self):
|
def test_get_aux_input(self):
|
||||||
aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
|
aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
|
||||||
|
|
Loading…
Reference in New Issue