mirror of https://github.com/coqui-ai/TTS.git
Start training by config.json using `register_config`
This commit is contained in:
parent
b3c073c99b
commit
ab563ce7cd
|
@ -1,8 +1,10 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from coqpit import Coqpit
|
||||||
|
|
||||||
from TTS.config.shared_configs import *
|
from TTS.config.shared_configs import *
|
||||||
from TTS.utils.generic_utils import find_module
|
from TTS.utils.generic_utils import find_module
|
||||||
|
@ -20,7 +22,18 @@ def read_json_with_comments(json_path):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _search_configs(model_name):
|
def register_config(model_name: str) -> Coqpit:
|
||||||
|
"""Find the right config for the given model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): Model name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ModuleNotFoundError: No matching config for the model name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Coqpit: config class.
|
||||||
|
"""
|
||||||
config_class = None
|
config_class = None
|
||||||
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
|
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
|
||||||
for path in paths:
|
for path in paths:
|
||||||
|
@ -33,7 +46,15 @@ def _search_configs(model_name):
|
||||||
return config_class
|
return config_class
|
||||||
|
|
||||||
|
|
||||||
def _process_model_name(config_dict):
|
def _process_model_name(config_dict: Dict) -> str:
|
||||||
|
"""Format the model name as expected. It is a band-aid for the old `vocoder` model names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict (Dict): A dictionary including the config fields.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Formatted model name.
|
||||||
|
"""
|
||||||
model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"]
|
model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"]
|
||||||
model_name = model_name.replace("_generator", "").replace("_discriminator", "")
|
model_name = model_name.replace("_generator", "").replace("_discriminator", "")
|
||||||
return model_name
|
return model_name
|
||||||
|
@ -69,7 +90,7 @@ def load_config(config_path: str) -> None:
|
||||||
raise TypeError(f" [!] Unknown config file type {ext}")
|
raise TypeError(f" [!] Unknown config file type {ext}")
|
||||||
config_dict.update(data)
|
config_dict.update(data)
|
||||||
model_name = _process_model_name(config_dict)
|
model_name = _process_model_name(config_dict)
|
||||||
config_class = _search_configs(model_name.lower())
|
config_class = register_config(model_name.lower())
|
||||||
config = config_class()
|
config = config_class()
|
||||||
config.from_dict(config_dict)
|
config.from_dict(config_dict)
|
||||||
return config
|
return config
|
||||||
|
|
|
@ -18,7 +18,7 @@ from torch import nn
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config, register_config
|
||||||
from TTS.tts.datasets import load_meta_data
|
from TTS.tts.datasets import load_meta_data
|
||||||
from TTS.tts.models import setup_model as setup_tts_model
|
from TTS.tts.models import setup_model as setup_tts_model
|
||||||
from TTS.tts.utils.text.symbols import parse_symbols
|
from TTS.tts.utils.text.symbols import parse_symbols
|
||||||
|
@ -940,7 +940,10 @@ def process_args(args, config=None):
|
||||||
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
|
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
|
||||||
logging to the console.
|
logging to the console.
|
||||||
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
|
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
|
||||||
the TensorBoard loggind.
|
the TensorBoard logging.
|
||||||
|
|
||||||
|
TODO:
|
||||||
|
- Interactive config definition.
|
||||||
"""
|
"""
|
||||||
if isinstance(args, tuple):
|
if isinstance(args, tuple):
|
||||||
args, coqpit_overrides = args
|
args, coqpit_overrides = args
|
||||||
|
@ -951,9 +954,17 @@ def process_args(args, config=None):
|
||||||
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
||||||
if not args.best_path:
|
if not args.best_path:
|
||||||
args.best_path = best_model
|
args.best_path = best_model
|
||||||
# setup output paths and read configs
|
# init config
|
||||||
if config is None:
|
if config is None and args.config_path:
|
||||||
|
# init from a file
|
||||||
config = load_config(args.config_path)
|
config = load_config(args.config_path)
|
||||||
|
else:
|
||||||
|
# init from console args
|
||||||
|
from TTS.config.shared_configs import BaseTrainingConfig
|
||||||
|
|
||||||
|
config_base = BaseTrainingConfig()
|
||||||
|
config_base.parse_known_args(coqpit_overrides)
|
||||||
|
config = register_config(config_base.model)()
|
||||||
# override values from command-line args
|
# override values from command-line args
|
||||||
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
|
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
|
||||||
if config.mixed_precision:
|
if config.mixed_precision:
|
||||||
|
|
Loading…
Reference in New Issue