From 4c50f6f4df6a2ed11958662fef9fdf226239e402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 16 Dec 2021 14:53:57 +0000 Subject: [PATCH] Add functions to get and check and argument in config and config.model_args --- TTS/bin/train_tts.py | 8 ++++---- TTS/config/__init__.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 191cba00..3360a940 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,6 +1,6 @@ import os -from TTS.config import load_config, register_config +from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config from TTS.trainer import Trainer, TrainingArgs from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model @@ -46,14 +46,14 @@ def main(): ap = AudioProcessor(**config.audio) # init speaker manager - if config.use_speaker_embedding: + if check_config_and_model_args(config, "use_speaker_embedding", True): speaker_manager = SpeakerManager(data_items=train_samples + eval_samples) if hasattr(config, "model_args"): config.model_args.num_speakers = speaker_manager.num_speakers else: config.num_speakers = speaker_manager.num_speakers - elif config.use_d_vector_file: - speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) + elif check_config_and_model_args(config, "use_d_vector_file", True): + speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file")) if hasattr(config, "model_args"): config.model_args.num_speakers = speaker_manager.num_speakers else: diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index f626163f..65950de6 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -95,3 +95,27 @@ def load_config(config_path: str) -> None: config = config_class() config.from_dict(config_dict) return config + + +def check_config_and_model_args(config, arg_name, value): + """Check the give argument in `config.model_args` if exist or in `config` for + the given value. + + It is to patch up the compatibility between models with and without `model_args`. + + TODO: Remove this in the future with a unified approach. + """ + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] == value + if hasattr(config, arg_name): + return config[arg_name] == value + raise ValueError(f" [!] {arg_name} is not found in config or config.model_args") + + +def get_from_config_or_model_args(config, arg_name): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + return config[arg_name]