mirror of https://github.com/coqui-ai/TTS.git
Use absolute imports for tts configs and models
This commit is contained in:
parent
82fed4add2
commit
e62d3c5cf7
|
@ -36,10 +36,11 @@ def register_config(model_name: str) -> Coqpit:
|
||||||
Coqpit: config class.
|
Coqpit: config class.
|
||||||
"""
|
"""
|
||||||
config_class = None
|
config_class = None
|
||||||
|
config_name = model_name + "_config"
|
||||||
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
|
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
|
||||||
for path in paths:
|
for path in paths:
|
||||||
try:
|
try:
|
||||||
config_class = find_module(path, model_name + "_config")
|
config_class = find_module(path, config_name)
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
pass
|
pass
|
||||||
if config_class is None:
|
if config_class is None:
|
||||||
|
|
|
@ -3,15 +3,15 @@ import os
|
||||||
from inspect import isclass
|
from inspect import isclass
|
||||||
|
|
||||||
# import all files under configs/
|
# import all files under configs/
|
||||||
configs_dir = os.path.dirname(__file__)
|
# configs_dir = os.path.dirname(__file__)
|
||||||
for file in os.listdir(configs_dir):
|
# for file in os.listdir(configs_dir):
|
||||||
path = os.path.join(configs_dir, file)
|
# path = os.path.join(configs_dir, file)
|
||||||
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
|
# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
|
||||||
config_name = file[: file.find(".py")] if file.endswith(".py") else file
|
# config_name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||||
module = importlib.import_module("TTS.tts.configs." + config_name)
|
# module = importlib.import_module("TTS.tts.configs." + config_name)
|
||||||
for attribute_name in dir(module):
|
# for attribute_name in dir(module):
|
||||||
attribute = getattr(module, attribute_name)
|
# attribute = getattr(module, attribute_name)
|
||||||
|
|
||||||
if isclass(attribute):
|
# if isclass(attribute):
|
||||||
# Add the class to this package's variables
|
# # Add the class to this package's variables
|
||||||
globals()[attribute_name] = attribute
|
# globals()[attribute_name] = attribute
|
||||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from tests import get_tests_output_path
|
from tests import get_tests_output_path
|
||||||
from TTS.tts.configs import BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||||
from TTS.tts.datasets import TTSDataset
|
from TTS.tts.datasets import TTSDataset
|
||||||
from TTS.tts.datasets.formatters import ljspeech
|
from TTS.tts.datasets.formatters import ljspeech
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import AlignTTSConfig
|
from TTS.tts.configs.align_tts_config import AlignTTSConfig
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
|
@ -4,7 +4,7 @@ import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.config.shared_configs import BaseAudioConfig
|
from TTS.config.shared_configs import BaseAudioConfig
|
||||||
from TTS.tts.configs import FastPitchConfig
|
from TTS.tts.configs.fast_pitch_config import FastPitchConfig
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.tts.configs import GlowTTSConfig
|
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||||
from TTS.tts.layers.losses import GlowTTSLoss
|
from TTS.tts.layers.losses import GlowTTSLoss
|
||||||
from TTS.tts.models.glow_tts import GlowTTS
|
from TTS.tts.models.glow_tts import GlowTTS
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import GlowTTSConfig
|
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import SpeedySpeechConfig
|
from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import Tacotron2Config
|
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
@ -23,7 +23,7 @@ config = Tacotron2Config(
|
||||||
epochs=1,
|
epochs=1,
|
||||||
print_step=1,
|
print_step=1,
|
||||||
print_eval=True,
|
print_eval=True,
|
||||||
use_speaker_embedding=True,
|
use_speaker_embedding=False,
|
||||||
use_d_vector_file=True,
|
use_d_vector_file=True,
|
||||||
test_sentences=[
|
test_sentences=[
|
||||||
"Be a voice, not an echo.",
|
"Be a voice, not an echo.",
|
||||||
|
|
|
@ -6,8 +6,8 @@ import torch
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.tts.configs import Tacotron2Config
|
|
||||||
from TTS.tts.configs.shared_configs import GSTConfig
|
from TTS.tts.configs.shared_configs import GSTConfig
|
||||||
|
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||||
from TTS.tts.layers.losses import MSELossMasked
|
from TTS.tts.layers.losses import MSELossMasked
|
||||||
from TTS.tts.models.tacotron2 import Tacotron2
|
from TTS.tts.models.tacotron2 import Tacotron2
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
@ -114,7 +114,7 @@ class MultiSpeakerTacotronTrainTest(unittest.TestCase):
|
||||||
assert (param - param_ref).sum() == 0, param
|
assert (param - param_ref).sum() == 0, param
|
||||||
count += 1
|
count += 1
|
||||||
optimizer = optim.Adam(model.parameters(), lr=config.lr)
|
optimizer = optim.Adam(model.parameters(), lr=config.lr)
|
||||||
for i in range(5):
|
for _ in range(5):
|
||||||
outputs = model.forward(
|
outputs = model.forward(
|
||||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import Tacotron2Config
|
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from TTS.tts.configs import Tacotron2Config
|
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||||
from TTS.tts.tf.models.tacotron2 import Tacotron2
|
from TTS.tts.tf.models.tacotron2 import Tacotron2
|
||||||
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model
|
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import Tacotron2Config
|
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import Tacotron2Config
|
from TTS.tts.configs.tacotron2_config import Tacotron2Config
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
|
@ -6,7 +6,8 @@ import torch
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.tts.configs import GSTConfig, TacotronConfig
|
from TTS.tts.configs.shared_configs import GSTConfig
|
||||||
|
from TTS.tts.configs.tacotron_config import TacotronConfig
|
||||||
from TTS.tts.layers.losses import L1LossMasked
|
from TTS.tts.layers.losses import L1LossMasked
|
||||||
from TTS.tts.models.tacotron import Tacotron
|
from TTS.tts.models.tacotron import Tacotron
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import TacotronConfig
|
from TTS.tts.configs.tacotron_config import TacotronConfig
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from tests import get_device_id, get_tests_output_path, run_cli
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
from TTS.tts.configs import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
|
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
Loading…
Reference in New Issue