From 37896e17430a5627b4b3224603b9101f3259a446 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 24 Mar 2022 14:16:04 -0300 Subject: [PATCH] Bug fix in freeze encoder (#1391) * Fix the bug in freeze encoder * Remove emb_l definition for non-multilingual training * Fix unit tests --- TTS/tts/models/vits.py | 1 - tests/tts_tests/test_vits.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index afadbadd..87d559fc 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -706,7 +706,6 @@ class Vits(BaseTTS): torch.nn.init.xavier_uniform_(self.emb_l.weight) else: self.embedded_language_dim = 0 - self.emb_l = None def get_aux_input(self, aux_input: Dict): sid, g, lid = self._set_cond_input(aux_input) diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 81d2ebbd..05adb9ed 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -79,25 +79,25 @@ class TestVits(unittest.TestCase): model = Vits(args) self.assertEqual(model.language_manager, None) self.assertEqual(model.embedded_language_dim, 0) - self.assertEqual(model.emb_l, None) + assertHasNotAttr(self, model, "emb_l") args = VitsArgs(language_ids_file=LANG_FILE) model = Vits(args) self.assertNotEqual(model.language_manager, None) self.assertEqual(model.embedded_language_dim, 0) - self.assertEqual(model.emb_l, None) + assertHasNotAttr(self, model, "emb_l") args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True) model = Vits(args) self.assertNotEqual(model.language_manager, None) self.assertEqual(model.embedded_language_dim, args.embedded_language_dim) - self.assertNotEqual(model.emb_l, None) + assertHasAttr(self, model, "emb_l") args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102) model = Vits(args) self.assertNotEqual(model.language_manager, None) self.assertEqual(model.embedded_language_dim, args.embedded_language_dim) - self.assertNotEqual(model.emb_l, None) + assertHasAttr(self, model, "emb_l") def test_get_aux_input(self): aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}