mirror of https://github.com/coqui-ai/TTS.git
Start training by config.json using `register_config`
This commit is contained in:
parent
b3c073c99b
commit
ab563ce7cd
|
@ -1,8 +1,10 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict
|
||||
|
||||
import yaml
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config.shared_configs import *
|
||||
from TTS.utils.generic_utils import find_module
|
||||
|
@ -20,7 +22,18 @@ def read_json_with_comments(json_path):
|
|||
return data
|
||||
|
||||
|
||||
def _search_configs(model_name):
|
||||
def register_config(model_name: str) -> Coqpit:
|
||||
"""Find the right config for the given model name.
|
||||
|
||||
Args:
|
||||
model_name (str): Model name.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: No matching config for the model name.
|
||||
|
||||
Returns:
|
||||
Coqpit: config class.
|
||||
"""
|
||||
config_class = None
|
||||
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
|
||||
for path in paths:
|
||||
|
@ -33,7 +46,15 @@ def _search_configs(model_name):
|
|||
return config_class
|
||||
|
||||
|
||||
def _process_model_name(config_dict):
|
||||
def _process_model_name(config_dict: Dict) -> str:
|
||||
"""Format the model name as expected. It is a band-aid for the old `vocoder` model names.
|
||||
|
||||
Args:
|
||||
config_dict (Dict): A dictionary including the config fields.
|
||||
|
||||
Returns:
|
||||
str: Formatted modelname.
|
||||
"""
|
||||
model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"]
|
||||
model_name = model_name.replace("_generator", "").replace("_discriminator", "")
|
||||
return model_name
|
||||
|
@ -69,7 +90,7 @@ def load_config(config_path: str) -> None:
|
|||
raise TypeError(f" [!] Unknown config file type {ext}")
|
||||
config_dict.update(data)
|
||||
model_name = _process_model_name(config_dict)
|
||||
config_class = _search_configs(model_name.lower())
|
||||
config_class = register_config(model_name.lower())
|
||||
config = config_class()
|
||||
config.from_dict(config_dict)
|
||||
return config
|
||||
|
|
|
@ -18,7 +18,7 @@ from torch import nn
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.tts.datasets import load_meta_data
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.tts.utils.text.symbols import parse_symbols
|
||||
|
@ -940,7 +940,10 @@ def process_args(args, config=None):
|
|||
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
|
||||
logging to the console.
|
||||
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
|
||||
the TensorBoard loggind.
|
||||
the TensorBoard logging.
|
||||
|
||||
TODO:
|
||||
- Interactive config definition.
|
||||
"""
|
||||
if isinstance(args, tuple):
|
||||
args, coqpit_overrides = args
|
||||
|
@ -951,9 +954,17 @@ def process_args(args, config=None):
|
|||
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
||||
if not args.best_path:
|
||||
args.best_path = best_model
|
||||
# setup output paths and read configs
|
||||
if config is None:
|
||||
# init config
|
||||
if config is None and args.config_path:
|
||||
# init from a file
|
||||
config = load_config(args.config_path)
|
||||
else:
|
||||
# init from console args
|
||||
from TTS.config.shared_configs import BaseTrainingConfig
|
||||
|
||||
config_base = BaseTrainingConfig()
|
||||
config_base.parse_known_args(coqpit_overrides)
|
||||
config = register_config(config_base.model)()
|
||||
# override values from command-line args
|
||||
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
|
||||
if config.mixed_precision:
|
||||
|
|
Loading…
Reference in New Issue