From 9c1bec86a4631adaeb3295fc4b779f62c8ba1fca Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Tue, 23 Nov 2021 16:00:38 +0100 Subject: [PATCH] Fix tests --- TTS/tts/models/vits.py | 2 +- .../test_vits_multilingual_train-d_vectors.py | 10 +++++----- tests/tts_tests/test_vits_multilingual_train.py | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index cc86e119..1b6d29d4 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -596,7 +596,7 @@ class Vits(BaseTTS): # language embedding lang_emb = None - if hasattr(self, "emb_l"): + if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1) x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py index f426e383..0e9827f1 100644 --- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -10,7 +10,7 @@ config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") -dataset_config1 = BaseDatasetConfig( +dataset_config_en = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", @@ -18,12 +18,12 @@ dataset_config1 = BaseDatasetConfig( language="en", ) -dataset_config2 = BaseDatasetConfig( +dataset_config_pt = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", - language="en2", + language="pt-br", ) config = VitsConfig( @@ -43,9 +43,9 @@ config = VitsConfig( print_eval=True, test_sentences=[ ["Be a voice, not an echo.", "ljspeech-0", None, "en"], - ["Be a voice, not an echo.", "ljspeech-1", None, "en2"], + ["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"], ], - datasets=[dataset_config1, dataset_config2], + datasets=[dataset_config_en, dataset_config_pt], ) # set audio config config.audio.do_trim_silence = True diff --git a/tests/tts_tests/test_vits_multilingual_train.py b/tests/tts_tests/test_vits_multilingual_train.py index 90f589d0..50cccca5 100644 --- a/tests/tts_tests/test_vits_multilingual_train.py +++ b/tests/tts_tests/test_vits_multilingual_train.py @@ -10,7 +10,7 @@ config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") -dataset_config1 = BaseDatasetConfig( +dataset_config_en = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", @@ -18,12 +18,12 @@ dataset_config1 = BaseDatasetConfig( language="en", ) -dataset_config2 = BaseDatasetConfig( +dataset_config_pt = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", - language="en2", + language="pt-br", ) config = VitsConfig( @@ -43,9 +43,9 @@ config = VitsConfig( print_eval=True, test_sentences=[ ["Be a voice, not an echo.", "ljspeech", None, "en"], - ["Be a voice, not an echo.", "ljspeech", None, "en2"], + ["Be a voice, not an echo.", "ljspeech", None, "pt-br"], ], - datasets=[dataset_config1, dataset_config2], + datasets=[dataset_config_en, dataset_config_pt], ) # set audio config config.audio.do_trim_silence = True