Add functions to get and check and argument in config and config.model_args

This commit is contained in:
Eren Gölge 2021-12-16 14:53:57 +00:00
parent 9ec6238f4a
commit abedfd586d
2 changed files with 28 additions and 4 deletions

View File

@ -1,6 +1,6 @@
import os
from TTS.config import load_config, register_config
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
@ -46,14 +46,14 @@ def main():
ap = AudioProcessor(**config.audio)
# init speaker manager
if config.use_speaker_embedding:
if check_config_and_model_args(config, "use_speaker_embedding", True):
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
if hasattr(config, "model_args"):
config.model_args.num_speakers = speaker_manager.num_speakers
else:
config.num_speakers = speaker_manager.num_speakers
elif config.use_d_vector_file:
speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
elif check_config_and_model_args(config, "use_d_vector_file", True):
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
if hasattr(config, "model_args"):
config.model_args.num_speakers = speaker_manager.num_speakers
else:

View File

@ -95,3 +95,27 @@ def load_config(config_path: str) -> None:
config = config_class()
config.from_dict(config_dict)
return config
def check_config_and_model_args(config, arg_name, value):
"""Check the give argument in `config.model_args` if exist or in `config` for
the given value.
It is to patch up the compatibility between models with and without `model_args`.
TODO: Remove this in the future with a unified approach.
"""
if hasattr(config, "model_args"):
if arg_name in config.model_args:
return config.model_args[arg_name] == value
if hasattr(config, arg_name):
return config[arg_name] == value
raise ValueError(f" [!] {arg_name} is not found in config or config.model_args")
def get_from_config_or_model_args(config, arg_name):
"""Get the given argument from `config.model_args` if exist or in `config`."""
if hasattr(config, "model_args"):
if arg_name in config.model_args:
return config.model_args[arg_name]
return config[arg_name]