mirror of https://github.com/coqui-ai/TTS.git
Add functions to get and check and argument in config and config.model_args
This commit is contained in:
parent
9ec6238f4a
commit
abedfd586d
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
|
||||
from TTS.trainer import Trainer, TrainingArgs
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models import setup_model
|
||||
|
@ -46,14 +46,14 @@ def main():
|
|||
ap = AudioProcessor(**config.audio)
|
||||
|
||||
# init speaker manager
|
||||
if config.use_speaker_embedding:
|
||||
if check_config_and_model_args(config, "use_speaker_embedding", True):
|
||||
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
|
||||
if hasattr(config, "model_args"):
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
else:
|
||||
config.num_speakers = speaker_manager.num_speakers
|
||||
elif config.use_d_vector_file:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
|
||||
elif check_config_and_model_args(config, "use_d_vector_file", True):
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
|
||||
if hasattr(config, "model_args"):
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
else:
|
||||
|
|
|
@ -95,3 +95,27 @@ def load_config(config_path: str) -> None:
|
|||
config = config_class()
|
||||
config.from_dict(config_dict)
|
||||
return config
|
||||
|
||||
|
||||
def check_config_and_model_args(config, arg_name, value):
|
||||
"""Check the give argument in `config.model_args` if exist or in `config` for
|
||||
the given value.
|
||||
|
||||
It is to patch up the compatibility between models with and without `model_args`.
|
||||
|
||||
TODO: Remove this in the future with a unified approach.
|
||||
"""
|
||||
if hasattr(config, "model_args"):
|
||||
if arg_name in config.model_args:
|
||||
return config.model_args[arg_name] == value
|
||||
if hasattr(config, arg_name):
|
||||
return config[arg_name] == value
|
||||
raise ValueError(f" [!] {arg_name} is not found in config or config.model_args")
|
||||
|
||||
|
||||
def get_from_config_or_model_args(config, arg_name):
|
||||
"""Get the given argument from `config.model_args` if exist or in `config`."""
|
||||
if hasattr(config, "model_args"):
|
||||
if arg_name in config.model_args:
|
||||
return config.model_args[arg_name]
|
||||
return config[arg_name]
|
||||
|
|
Loading…
Reference in New Issue