From 757e90b1cc69fb928b869c7246d92b1a5d2aaceb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 May 2021 03:40:34 +0200 Subject: [PATCH] load_config function to initialize the right Coqpit for the given model --- TTS/config/__init__.py | 40 ++++++++++++++++++++++++++++++++++++++++ TTS/utils/arguments.py | 3 ++- 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 TTS/config/__init__.py diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py new file mode 100644 index 00000000..85e7d9b9 --- /dev/null +++ b/TTS/config/__init__.py @@ -0,0 +1,40 @@ +from TTS.config.shared_configs import * + +import json +import os + +import yaml + +from TTS.utils.generic_utils import find_module + + +def _search_configs(model_name): + config_class = None + paths = ["TTS.tts.configs", "TTS.vocoder.configs"] + for path in paths: + try: + config_class = find_module(path, model_name + "_config") + except ModuleNotFoundError: + pass + if config_class is None: + raise ModuleNotFoundError() + return config_class + + +def load_config(config_path: str) -> None: + config_dict = {} + ext = os.path.splitext(config_path)[1] + if ext in (".yml", ".yaml"): + with open(config_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + elif ext == ".json": + with open(config_path, "r", encoding="utf-8") as f: + input_str = f.read() + data = json.loads(input_str) + else: + raise TypeError(f" [!] Unknown config file type {ext}") + config_dict.update(data) + config_class = _search_configs(config_dict["model"].lower()) + config = config_class() + config.from_dict(config_dict) + return config diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py index 35fa80eb..cf64edae 100644 --- a/TTS/utils/arguments.py +++ b/TTS/utils/arguments.py @@ -12,7 +12,8 @@ import torch from TTS.tts.utils.text.symbols import parse_symbols from TTS.utils.console_logger import ConsoleLogger from TTS.utils.generic_utils import create_experiment_folder, get_git_branch -from TTS.utils.io import copy_model_files, load_config +from TTS.utils.io import copy_model_files +from TTS.config import load_config from TTS.utils.tensorboard_logger import TensorboardLogger