Start training by config.json using `register_config`

This commit is contained in:
Eren Gölge 2021-06-26 18:33:17 +02:00
parent b3c073c99b
commit ab563ce7cd
2 changed files with 39 additions and 7 deletions

View File

@ -1,8 +1,10 @@
import json
import os
import re
from typing import Dict
import yaml
from coqpit import Coqpit
from TTS.config.shared_configs import *
from TTS.utils.generic_utils import find_module
@ -20,7 +22,18 @@ def read_json_with_comments(json_path):
return data
def _search_configs(model_name):
def register_config(model_name: str) -> Coqpit:
"""Find the right config for the given model name.
Args:
model_name (str): Model name.
Raises:
ModuleNotFoundError: No matching config for the model name.
Returns:
Coqpit: config class.
"""
config_class = None
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
for path in paths:
@ -33,7 +46,15 @@ def _search_configs(model_name):
return config_class
def _process_model_name(config_dict):
def _process_model_name(config_dict: Dict) -> str:
"""Format the model name as expected. It is a band-aid for the old `vocoder` model names.
Args:
config_dict (Dict): A dictionary including the config fields.
Returns:
str: Formatted modelname.
"""
model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"]
model_name = model_name.replace("_generator", "").replace("_discriminator", "")
return model_name
@ -69,7 +90,7 @@ def load_config(config_path: str) -> None:
raise TypeError(f" [!] Unknown config file type {ext}")
config_dict.update(data)
model_name = _process_model_name(config_dict)
config_class = _search_configs(model_name.lower())
config_class = register_config(model_name.lower())
config = config_class()
config.from_dict(config_dict)
return config

View File

@ -18,7 +18,7 @@ from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from TTS.config import load_config
from TTS.config import load_config, register_config
from TTS.tts.datasets import load_meta_data
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.text.symbols import parse_symbols
@ -940,7 +940,10 @@ def process_args(args, config=None):
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
logging to the console.
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
the TensorBoard loggind.
the TensorBoard logging.
TODO:
- Interactive config definition.
"""
if isinstance(args, tuple):
args, coqpit_overrides = args
@ -951,9 +954,17 @@ def process_args(args, config=None):
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model
# setup output paths and read configs
if config is None:
# init config
if config is None and args.config_path:
# init from a file
config = load_config(args.config_path)
else:
# init from console args
from TTS.config.shared_configs import BaseTrainingConfig
config_base = BaseTrainingConfig()
config_base.parse_known_args(coqpit_overrides)
config = register_config(config_base.model)()
# override values from command-line args
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
if config.mixed_precision: