Use absolute imports for tts configs and models

This commit is contained in:
Eren Gölge 2021-10-21 16:08:13 +00:00
parent 82fed4add2
commit e62d3c5cf7
17 changed files with 31 additions and 29 deletions

View File

@ -36,10 +36,11 @@ def register_config(model_name: str) -> Coqpit:
Coqpit: config class. Coqpit: config class.
""" """
config_class = None config_class = None
config_name = model_name + "_config"
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"] paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
for path in paths: for path in paths:
try: try:
config_class = find_module(path, model_name + "_config") config_class = find_module(path, config_name)
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
if config_class is None: if config_class is None:

View File

@ -3,15 +3,15 @@ import os
from inspect import isclass from inspect import isclass
# import all files under configs/ # import all files under configs/
configs_dir = os.path.dirname(__file__) # configs_dir = os.path.dirname(__file__)
for file in os.listdir(configs_dir): # for file in os.listdir(configs_dir):
path = os.path.join(configs_dir, file) # path = os.path.join(configs_dir, file)
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): # if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
config_name = file[: file.find(".py")] if file.endswith(".py") else file # config_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("TTS.tts.configs." + config_name) # module = importlib.import_module("TTS.tts.configs." + config_name)
for attribute_name in dir(module): # for attribute_name in dir(module):
attribute = getattr(module, attribute_name) # attribute = getattr(module, attribute_name)
if isclass(attribute): # if isclass(attribute):
# Add the class to this package's variables # # Add the class to this package's variables
globals()[attribute_name] = attribute # globals()[attribute_name] = attribute

View File

@ -7,7 +7,7 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tests import get_tests_output_path from tests import get_tests_output_path
from TTS.tts.configs import BaseTTSConfig from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.datasets import TTSDataset from TTS.tts.datasets import TTSDataset
from TTS.tts.datasets.formatters import ljspeech from TTS.tts.datasets.formatters import ljspeech
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import AlignTTSConfig from TTS.tts.configs.align_tts_config import AlignTTSConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -4,7 +4,7 @@ import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs import FastPitchConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig
config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json") config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -6,7 +6,7 @@ import torch
from torch import optim from torch import optim
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.tts.configs import GlowTTSConfig from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTTS from TTS.tts.models.glow_tts import GlowTTS
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import GlowTTSConfig from TTS.tts.configs.glow_tts_config import GlowTTSConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import SpeedySpeechConfig from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json") config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")
@ -23,7 +23,7 @@ config = Tacotron2Config(
epochs=1, epochs=1,
print_step=1, print_step=1,
print_eval=True, print_eval=True,
use_speaker_embedding=True, use_speaker_embedding=False,
use_d_vector_file=True, use_d_vector_file=True,
test_sentences=[ test_sentences=[
"Be a voice, not an echo.", "Be a voice, not an echo.",

View File

@ -6,8 +6,8 @@ import torch
from torch import nn, optim from torch import nn, optim
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.tts.configs import Tacotron2Config
from TTS.tts.configs.shared_configs import GSTConfig from TTS.tts.configs.shared_configs import GSTConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.layers.losses import MSELossMasked from TTS.tts.layers.losses import MSELossMasked
from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.models.tacotron2 import Tacotron2
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@ -114,7 +114,7 @@ class MultiSpeakerTacotronTrainTest(unittest.TestCase):
assert (param - param_ref).sum() == 0, param assert (param - param_ref).sum() == 0, param
count += 1 count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr) optimizer = optim.Adam(model.parameters(), lr=config.lr)
for i in range(5): for _ in range(5):
outputs = model.forward( outputs = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids} input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
) )

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -5,7 +5,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
from TTS.tts.configs import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.tf.models.tacotron2 import Tacotron2 from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config from TTS.tts.configs.tacotron2_config import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -6,7 +6,8 @@ import torch
from torch import nn, optim from torch import nn, optim
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.tts.configs import GSTConfig, TacotronConfig from TTS.tts.configs.shared_configs import GSTConfig
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.layers.losses import L1LossMasked from TTS.tts.layers.losses import L1LossMasked
from TTS.tts.models.tacotron import Tacotron from TTS.tts.models.tacotron import Tacotron
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import TacotronConfig from TTS.tts.configs.tacotron_config import TacotronConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil import shutil
from tests import get_device_id, get_tests_output_path, run_cli from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json") config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")