From e62d3c5cf7031e8f8113aca68f35508db8453f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 21 Oct 2021 16:08:13 +0000 Subject: [PATCH] Use absolute imports for tts configs and models --- TTS/config/__init__.py | 3 ++- TTS/tts/configs/__init__.py | 22 +++++++++---------- tests/data_tests/test_loader.py | 2 +- tests/tts_tests/test_align_tts_train.py | 2 +- tests/tts_tests/test_fast_pitch_train.py | 2 +- tests/tts_tests/test_glow_tts.py | 2 +- tests/tts_tests/test_glow_tts_train.py | 2 +- tests/tts_tests/test_speedy_speech_train.py | 2 +- .../test_tacotron2_d-vectors_train.py | 4 ++-- tests/tts_tests/test_tacotron2_model.py | 4 ++-- .../test_tacotron2_speaker_emb_train.py | 2 +- tests/tts_tests/test_tacotron2_tf_model.py | 2 +- tests/tts_tests/test_tacotron2_train.py | 2 +- .../test_tacotron2_train_fsspec_path.py | 2 +- tests/tts_tests/test_tacotron_model.py | 3 ++- tests/tts_tests/test_tacotron_train.py | 2 +- tests/tts_tests/test_vits_train.py | 2 +- 17 files changed, 31 insertions(+), 29 deletions(-) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index ea98f431..f626163f 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -36,10 +36,11 @@ def register_config(model_name: str) -> Coqpit: Coqpit: config class. """ config_class = None + config_name = model_name + "_config" paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"] for path in paths: try: - config_class = find_module(path, model_name + "_config") + config_class = find_module(path, config_name) except ModuleNotFoundError: pass if config_class is None: diff --git a/TTS/tts/configs/__init__.py b/TTS/tts/configs/__init__.py index 5ad4fe8c..3146ac1c 100644 --- a/TTS/tts/configs/__init__.py +++ b/TTS/tts/configs/__init__.py @@ -3,15 +3,15 @@ import os from inspect import isclass # import all files under configs/ -configs_dir = os.path.dirname(__file__) -for file in os.listdir(configs_dir): - path = os.path.join(configs_dir, file) - if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): - config_name = file[: file.find(".py")] if file.endswith(".py") else file - module = importlib.import_module("TTS.tts.configs." + config_name) - for attribute_name in dir(module): - attribute = getattr(module, attribute_name) +# configs_dir = os.path.dirname(__file__) +# for file in os.listdir(configs_dir): +# path = os.path.join(configs_dir, file) +# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): +# config_name = file[: file.find(".py")] if file.endswith(".py") else file +# module = importlib.import_module("TTS.tts.configs." + config_name) +# for attribute_name in dir(module): +# attribute = getattr(module, attribute_name) - if isclass(attribute): - # Add the class to this package's variables - globals()[attribute_name] = attribute +# if isclass(attribute): +# # Add the class to this package's variables +# globals()[attribute_name] = attribute diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 18066ef3..8a20c261 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -7,7 +7,7 @@ import torch from torch.utils.data import DataLoader from tests import get_tests_output_path -from TTS.tts.configs import BaseTTSConfig +from TTS.tts.configs.shared_configs import BaseTTSConfig from TTS.tts.datasets import TTSDataset from TTS.tts.datasets.formatters import ljspeech from TTS.utils.audio import AudioProcessor diff --git a/tests/tts_tests/test_align_tts_train.py b/tests/tts_tests/test_align_tts_train.py index f04a2358..f5d60d7c 100644 --- a/tests/tts_tests/test_align_tts_train.py +++ b/tests/tts_tests/test_align_tts_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import AlignTTSConfig +from TTS.tts.configs.align_tts_config import AlignTTSConfig config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_fast_pitch_train.py b/tests/tts_tests/test_fast_pitch_train.py index 89176ac9..71ba8b25 100644 --- a/tests/tts_tests/test_fast_pitch_train.py +++ b/tests/tts_tests/test_fast_pitch_train.py @@ -4,7 +4,7 @@ import shutil from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig -from TTS.tts.configs import FastPitchConfig +from TTS.tts.configs.fast_pitch_config import FastPitchConfig config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index e139562c..82d0ec3b 100644 --- a/tests/tts_tests/test_glow_tts.py +++ b/tests/tts_tests/test_glow_tts.py @@ -6,7 +6,7 @@ import torch from torch import optim from tests import get_tests_input_path -from TTS.tts.configs import GlowTTSConfig +from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.models.glow_tts import GlowTTS from TTS.utils.audio import AudioProcessor diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests/test_glow_tts_train.py index 7da4fd33..e5901076 100644 --- a/tests/tts_tests/test_glow_tts_train.py +++ b/tests/tts_tests/test_glow_tts_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import GlowTTSConfig +from TTS.tts.configs.glow_tts_config import GlowTTSConfig config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_speedy_speech_train.py b/tests/tts_tests/test_speedy_speech_train.py index a181ac24..7d7ecc7c 100644 --- a/tests/tts_tests/test_speedy_speech_train.py +++ b/tests/tts_tests/test_speedy_speech_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import SpeedySpeechConfig +from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py index 1a8d78bf..c817badc 100644 --- a/tests/tts_tests/test_tacotron2_d-vectors_train.py +++ b/tests/tts_tests/test_tacotron2_d-vectors_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import Tacotron2Config +from TTS.tts.configs.tacotron2_config import Tacotron2Config config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -23,7 +23,7 @@ config = Tacotron2Config( epochs=1, print_step=1, print_eval=True, - use_speaker_embedding=True, + use_speaker_embedding=False, use_d_vector_file=True, test_sentences=[ "Be a voice, not an echo.", diff --git a/tests/tts_tests/test_tacotron2_model.py b/tests/tts_tests/test_tacotron2_model.py index 65d2bd9d..df184a6a 100644 --- a/tests/tts_tests/test_tacotron2_model.py +++ b/tests/tts_tests/test_tacotron2_model.py @@ -6,8 +6,8 @@ import torch from torch import nn, optim from tests import get_tests_input_path -from TTS.tts.configs import Tacotron2Config from TTS.tts.configs.shared_configs import GSTConfig +from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.layers.losses import MSELossMasked from TTS.tts.models.tacotron2 import Tacotron2 from TTS.utils.audio import AudioProcessor @@ -114,7 +114,7 @@ class MultiSpeakerTacotronTrainTest(unittest.TestCase): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=config.lr) - for i in range(5): + for _ in range(5): outputs = model.forward( input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids} ) diff --git a/tests/tts_tests/test_tacotron2_speaker_emb_train.py b/tests/tts_tests/test_tacotron2_speaker_emb_train.py index 41d694f6..095016d8 100644 --- a/tests/tts_tests/test_tacotron2_speaker_emb_train.py +++ b/tests/tts_tests/test_tacotron2_speaker_emb_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import Tacotron2Config +from TTS.tts.configs.tacotron2_config import Tacotron2Config config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_tacotron2_tf_model.py b/tests/tts_tests/test_tacotron2_tf_model.py index 515a6834..fb1efcde 100644 --- a/tests/tts_tests/test_tacotron2_tf_model.py +++ b/tests/tts_tests/test_tacotron2_tf_model.py @@ -5,7 +5,7 @@ import numpy as np import tensorflow as tf import torch -from TTS.tts.configs import Tacotron2Config +from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.tf.models.tacotron2 import Tacotron2 from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index e947a54a..4f37ef89 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import Tacotron2Config +from TTS.tts.configs.tacotron2_config import Tacotron2Config config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_tacotron2_train_fsspec_path.py b/tests/tts_tests/test_tacotron2_train_fsspec_path.py index 9e4ee102..5d14a983 100644 --- a/tests/tts_tests/test_tacotron2_train_fsspec_path.py +++ b/tests/tts_tests/test_tacotron2_train_fsspec_path.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import Tacotron2Config +from TTS.tts.configs.tacotron2_config import Tacotron2Config config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_tacotron_model.py b/tests/tts_tests/test_tacotron_model.py index 3f570276..6e0e712b 100644 --- a/tests/tts_tests/test_tacotron_model.py +++ b/tests/tts_tests/test_tacotron_model.py @@ -6,7 +6,8 @@ import torch from torch import nn, optim from tests import get_tests_input_path -from TTS.tts.configs import GSTConfig, TacotronConfig +from TTS.tts.configs.shared_configs import GSTConfig +from TTS.tts.configs.tacotron_config import TacotronConfig from TTS.tts.layers.losses import L1LossMasked from TTS.tts.models.tacotron import Tacotron from TTS.utils.audio import AudioProcessor diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py index 0c35ee28..68071c66 100644 --- a/tests/tts_tests/test_tacotron_train.py +++ b/tests/tts_tests/test_tacotron_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import TacotronConfig +from TTS.tts.configs.tacotron_config import TacotronConfig config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py index db9d2fc1..6398955e 100644 --- a/tests/tts_tests/test_vits_train.py +++ b/tests/tts_tests/test_vits_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import VitsConfig +from TTS.tts.configs.vits_config import VitsConfig config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs")