Use absolute imports for tts configs and models

This commit is contained in:
Eren Gölge 2021-10-21 16:08:13 +00:00
parent 82fed4add2
commit e62d3c5cf7
17 changed files with 31 additions and 29 deletions

View File

@ -36,10 +36,11 @@ def register_config(model_name: str) -> Coqpit:
Coqpit: config class.
"""
config_class = None
config_name = model_name + "_config"
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
for path in paths:
try:
config_class = find_module(path, model_name + "_config")
config_class = find_module(path, config_name)
except ModuleNotFoundError:
pass
if config_class is None:

View File

@ -3,15 +3,15 @@ import os
from inspect import isclass
# import all files under configs/
configs_dir = os.path.dirname(__file__)
for file in os.listdir(configs_dir):
path = os.path.join(configs_dir, file)
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
config_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("TTS.tts.configs." + config_name)
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
# configs_dir = os.path.dirname(__file__)
# for file in os.listdir(configs_dir):
# path = os.path.join(configs_dir, file)
# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
# config_name = file[: file.find(".py")] if file.endswith(".py") else file
# module = importlib.import_module("TTS.tts.configs." + config_name)
# for attribute_name in dir(module):
# attribute = getattr(module, attribute_name)
if isclass(attribute):
# Add the class to this package's variables
globals()[attribute_name] = attribute
# if isclass(attribute):
# # Add the class to this package's variables
# globals()[attribute_name] = attribute

View File

@ -7,7 +7,7 @@ import torch
from torch.utils.data import DataLoader
from tests import get_tests_output_path
from TTS.tts.configs import BaseTTSConfig
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.datasets import TTSDataset
from TTS.tts.datasets.formatters import ljspeech
from TTS.utils.audio import AudioProcessor

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import AlignTTSConfig
from TTS.tts.configs.align_tts_config import AlignTTSConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -4,7 +4,7 @@ import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs import FastPitchConfig
from TTS.tts.configs.fast_pitch_config import FastPitchConfig
config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -6,7 +6,7 @@ import torch
from torch import optim
from tests import get_tests_input_path
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTTS
from TTS.utils.audio import AudioProcessor

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import SpeedySpeechConfig
from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config
from TTS.tts.configs.tacotron2_config import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
@ -23,7 +23,7 @@ config = Tacotron2Config(
epochs=1,
print_step=1,
print_eval=True,
use_speaker_embedding=True,
use_speaker_embedding=False,
use_d_vector_file=True,
test_sentences=[
"Be a voice, not an echo.",

View File

@ -6,8 +6,8 @@ import torch
from torch import nn, optim
from tests import get_tests_input_path
from TTS.tts.configs import Tacotron2Config
from TTS.tts.configs.shared_configs import GSTConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.layers.losses import MSELossMasked
from TTS.tts.models.tacotron2 import Tacotron2
from TTS.utils.audio import AudioProcessor
@ -114,7 +114,7 @@ class MultiSpeakerTacotronTrainTest(unittest.TestCase):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for i in range(5):
for _ in range(5):
outputs = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
)

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config
from TTS.tts.configs.tacotron2_config import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -5,7 +5,7 @@ import numpy as np
import tensorflow as tf
import torch
from TTS.tts.configs import Tacotron2Config
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config
from TTS.tts.configs.tacotron2_config import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config
from TTS.tts.configs.tacotron2_config import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -6,7 +6,8 @@ import torch
from torch import nn, optim
from tests import get_tests_input_path
from TTS.tts.configs import GSTConfig, TacotronConfig
from TTS.tts.configs.shared_configs import GSTConfig
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.layers.losses import L1LossMasked
from TTS.tts.models.tacotron import Tacotron
from TTS.utils.audio import AudioProcessor

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import TacotronConfig
from TTS.tts.configs.tacotron_config import TacotronConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

View File

@ -3,7 +3,7 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")