mirror of https://github.com/coqui-ai/TTS.git
Add functions to get and check and argument in config and config.model_args
This commit is contained in:
parent
3c6d7f495c
commit
4c50f6f4df
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from TTS.config import load_config, register_config
|
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
|
||||||
from TTS.trainer import Trainer, TrainingArgs
|
from TTS.trainer import Trainer, TrainingArgs
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
|
@ -46,14 +46,14 @@ def main():
|
||||||
ap = AudioProcessor(**config.audio)
|
ap = AudioProcessor(**config.audio)
|
||||||
|
|
||||||
# init speaker manager
|
# init speaker manager
|
||||||
if config.use_speaker_embedding:
|
if check_config_and_model_args(config, "use_speaker_embedding", True):
|
||||||
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
|
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
|
||||||
if hasattr(config, "model_args"):
|
if hasattr(config, "model_args"):
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
else:
|
else:
|
||||||
config.num_speakers = speaker_manager.num_speakers
|
config.num_speakers = speaker_manager.num_speakers
|
||||||
elif config.use_d_vector_file:
|
elif check_config_and_model_args(config, "use_d_vector_file", True):
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
|
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
|
||||||
if hasattr(config, "model_args"):
|
if hasattr(config, "model_args"):
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -95,3 +95,27 @@ def load_config(config_path: str) -> None:
|
||||||
config = config_class()
|
config = config_class()
|
||||||
config.from_dict(config_dict)
|
config.from_dict(config_dict)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def check_config_and_model_args(config, arg_name, value):
|
||||||
|
"""Check the give argument in `config.model_args` if exist or in `config` for
|
||||||
|
the given value.
|
||||||
|
|
||||||
|
It is to patch up the compatibility between models with and without `model_args`.
|
||||||
|
|
||||||
|
TODO: Remove this in the future with a unified approach.
|
||||||
|
"""
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
if arg_name in config.model_args:
|
||||||
|
return config.model_args[arg_name] == value
|
||||||
|
if hasattr(config, arg_name):
|
||||||
|
return config[arg_name] == value
|
||||||
|
raise ValueError(f" [!] {arg_name} is not found in config or config.model_args")
|
||||||
|
|
||||||
|
|
||||||
|
def get_from_config_or_model_args(config, arg_name):
|
||||||
|
"""Get the given argument from `config.model_args` if exist or in `config`."""
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
if arg_name in config.model_args:
|
||||||
|
return config.model_args[arg_name]
|
||||||
|
return config[arg_name]
|
||||||
|
|
Loading…
Reference in New Issue