load_config function to initialize the right Coqpit for the given model

This commit is contained in:
Eren Gölge 2021-05-07 03:40:34 +02:00
parent e6f45b9eb7
commit 757e90b1cc
2 changed files with 42 additions and 1 deletions

40
TTS/config/__init__.py Normal file
View File

@ -0,0 +1,40 @@
from TTS.config.shared_configs import *
import json
import os
import yaml
from TTS.utils.generic_utils import find_module
def _search_configs(model_name):
config_class = None
paths = ["TTS.tts.configs", "TTS.vocoder.configs"]
for path in paths:
try:
config_class = find_module(path, model_name + "_config")
except ModuleNotFoundError:
pass
if config_class is None:
raise ModuleNotFoundError()
return config_class
def load_config(config_path: str) -> None:
config_dict = {}
ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"):
with open(config_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
elif ext == ".json":
with open(config_path, "r", encoding="utf-8") as f:
input_str = f.read()
data = json.loads(input_str)
else:
raise TypeError(f" [!] Unknown config file type {ext}")
config_dict.update(data)
config_class = _search_configs(config_dict["model"].lower())
config = config_class()
config.from_dict(config_dict)
return config

View File

@ -12,7 +12,8 @@ import torch
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.io import copy_model_files
from TTS.config import load_config
from TTS.utils.tensorboard_logger import TensorboardLogger