mirror of https://github.com/coqui-ai/TTS.git
Use absolute imports for tts configs and models
This commit is contained in:
parent
82fed4add2
commit
e62d3c5cf7
|
@ -36,10 +36,11 @@ def register_config(model_name: str) -> Coqpit:
|
|||
Coqpit: config class.
|
||||
"""
|
||||
config_class = None
|
||||
config_name = model_name + "_config"
|
||||
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
|
||||
for path in paths:
|
||||
try:
|
||||
config_class = find_module(path, model_name + "_config")
|
||||
config_class = find_module(path, config_name)
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
if config_class is None:
|
||||
|
|
|
@ -3,15 +3,15 @@ import os
|
|||
from inspect import isclass
|
||||
|
||||
# import all files under configs/
|
||||
configs_dir = os.path.dirname(__file__)
|
||||
for file in os.listdir(configs_dir):
|
||||
path = os.path.join(configs_dir, file)
|
||||
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
|
||||
config_name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
module = importlib.import_module("TTS.tts.configs." + config_name)
|
||||
for attribute_name in dir(module):
|
||||
attribute = getattr(module, attribute_name)
|
||||
# configs_dir = os.path.dirname(__file__)
|
||||
# for file in os.listdir(configs_dir):
|
||||
# path = os.path.join(configs_dir, file)
|
||||
# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
|
||||
# config_name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
# module = importlib.import_module("TTS.tts.configs." + config_name)
|
||||
# for attribute_name in dir(module):
|
||||
# attribute = getattr(module, attribute_name)
|
||||
|
||||
if isclass(attribute):
|
||||
# Add the class to this package's variables
|
||||
globals()[attribute_name] = attribute
|
||||
# if isclass(attribute):
|
||||
# # Add the class to this package's variables
|
||||
# globals()[attribute_name] = attribute
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from tests import get_tests_output_path
|
||||
from TTS.tts.configs import BaseTTSConfig
|
||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||
from TTS.tts.datasets import TTSDataset
|
||||
from TTS.tts.datasets.formatters import ljspeech
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import AlignTTSConfig
|
||||
from TTS.tts.configs.align_tts_config import AlignTTSConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -4,7 +4,7 @@ import shutil
|
|||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.config.shared_configs import BaseAudioConfig
|
||||
from TTS.tts.configs import FastPitchConfig
|
||||
from TTS.tts.configs.fast_pitch_config import FastPitchConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from torch import optim
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.tts.configs import GlowTTSConfig
|
||||
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||
from TTS.tts.layers.losses import GlowTTSLoss
|
||||
from TTS.tts.models.glow_tts import GlowTTS
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import GlowTTSConfig
|
||||
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import SpeedySpeechConfig
|
||||
from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import Tacotron2Config
|
||||
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
@ -23,7 +23,7 @@ config = Tacotron2Config(
|
|||
epochs=1,
|
||||
print_step=1,
|
||||
print_eval=True,
|
||||
use_speaker_embedding=True,
|
||||
use_speaker_embedding=False,
|
||||
use_d_vector_file=True,
|
||||
test_sentences=[
|
||||
"Be a voice, not an echo.",
|
||||
|
|
|
@ -6,8 +6,8 @@ import torch
|
|||
from torch import nn, optim
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.tts.configs import Tacotron2Config
|
||||
from TTS.tts.configs.shared_configs import GSTConfig
|
||||
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||
from TTS.tts.layers.losses import MSELossMasked
|
||||
from TTS.tts.models.tacotron2 import Tacotron2
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
@ -114,7 +114,7 @@ class MultiSpeakerTacotronTrainTest(unittest.TestCase):
|
|||
assert (param - param_ref).sum() == 0, param
|
||||
count += 1
|
||||
optimizer = optim.Adam(model.parameters(), lr=config.lr)
|
||||
for i in range(5):
|
||||
for _ in range(5):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import Tacotron2Config
|
||||
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from TTS.tts.configs import Tacotron2Config
|
||||
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||
from TTS.tts.tf.models.tacotron2 import Tacotron2
|
||||
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import Tacotron2Config
|
||||
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import Tacotron2Config
|
||||
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -6,7 +6,8 @@ import torch
|
|||
from torch import nn, optim
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.tts.configs import GSTConfig, TacotronConfig
|
||||
from TTS.tts.configs.shared_configs import GSTConfig
|
||||
from TTS.tts.configs.tacotron_config import TacotronConfig
|
||||
from TTS.tts.layers.losses import L1LossMasked
|
||||
from TTS.tts.models.tacotron import Tacotron
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import TacotronConfig
|
||||
from TTS.tts.configs.tacotron_config import TacotronConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import VitsConfig
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
|
Loading…
Reference in New Issue