Rename the config lookup `_search_configs` to the public `register_config` and
let `process_args` build a config from console arguments when no config file
is given. A config object passed in by the caller is left untouched: the
file/console initialisation only runs when `config is None`.

diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py
index b4f1cbea..ecbe1f9a 100644
--- a/TTS/config/__init__.py
+++ b/TTS/config/__init__.py
@@ -1,8 +1,10 @@
 import json
 import os
 import re
+from typing import Dict
 
 import yaml
+from coqpit import Coqpit
 
 from TTS.config.shared_configs import *
 from TTS.utils.generic_utils import find_module
@@ -20,7 +22,18 @@ def read_json_with_comments(json_path):
     return data
 
 
-def _search_configs(model_name):
+def register_config(model_name: str) -> Coqpit:
+    """Find the right config for the given model name.
+
+    Args:
+        model_name (str): Model name.
+
+    Raises:
+        ModuleNotFoundError: No matching config for the model name.
+
+    Returns:
+        Coqpit: config class.
+    """
     config_class = None
     paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
     for path in paths:
@@ -33,7 +46,15 @@ def _search_configs(model_name):
     return config_class
 
 
-def _process_model_name(config_dict):
+def _process_model_name(config_dict: Dict) -> str:
+    """Format the model name as expected. It is a band-aid for the old `vocoder` model names.
+
+    Args:
+        config_dict (Dict): A dictionary including the config fields.
+
+    Returns:
+        str: Formatted model name.
+    """
     model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"]
     model_name = model_name.replace("_generator", "").replace("_discriminator", "")
     return model_name
@@ -69,7 +90,7 @@ def load_config(config_path: str) -> None:
         raise TypeError(f" [!] Unknown config file type {ext}")
     config_dict.update(data)
     model_name = _process_model_name(config_dict)
-    config_class = _search_configs(model_name.lower())
+    config_class = register_config(model_name.lower())
     config = config_class()
     config.from_dict(config_dict)
     return config
diff --git a/TTS/trainer.py b/TTS/trainer.py
index d5aec1c9..e3403bae 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -18,7 +18,7 @@ from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 
-from TTS.config import load_config
+from TTS.config import load_config, register_config
 from TTS.tts.datasets import load_meta_data
 from TTS.tts.models import setup_model as setup_tts_model
 from TTS.tts.utils.text.symbols import parse_symbols
@@ -940,7 +940,10 @@ def process_args(args, config=None):
         c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
             logging to the console.
         tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
-            the TensorBoard loggind.
+            the TensorBoard logging.
+
+    TODO:
+        - Interactive config definition.
     """
     if isinstance(args, tuple):
         args, coqpit_overrides = args
@@ -951,9 +954,18 @@ def process_args(args, config=None):
         args.restore_path, best_model = get_last_checkpoint(args.continue_path)
         if not args.best_path:
             args.best_path = best_model
-    # setup output paths and read configs
+    # init config only when the caller did not pass one in
     if config is None:
-        config = load_config(args.config_path)
+        if args.config_path:
+            # init from a file
+            config = load_config(args.config_path)
+        else:
+            # init from console args
+            from TTS.config.shared_configs import BaseTrainingConfig
+
+            config_base = BaseTrainingConfig()
+            config_base.parse_known_args(coqpit_overrides)
+            config = register_config(config_base.model)()
     # override values from command-line args
     config.parse_known_args(coqpit_overrides, relaxed_parser=True)
     if config.mixed_precision: