diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index 65950de6..8ed3578f 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -101,7 +101,8 @@ def check_config_and_model_args(config, arg_name, value): """Check the give argument in `config.model_args` if exist or in `config` for the given value. - It is to patch up the compatibility between models with and without `model_args`. + Return False if the argument does not exist in `config.model_args` or `config`. + This is to patch up the compatibility between models with and without `model_args`. TODO: Remove this in the future with a unified approach. """ @@ -110,7 +111,7 @@ def check_config_and_model_args(config, arg_name, value): return config.model_args[arg_name] == value if hasattr(config, arg_name): return config[arg_name] == value - raise ValueError(f" [!] {arg_name} is not found in config or config.model_args") + return False def get_from_config_or_model_args(config, arg_name): diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 62540ae2..905f50d7 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -5,7 +5,7 @@ import numpy as np import pysbd import torch -from TTS.config import load_config +from TTS.config import check_config_and_model_args, load_config from TTS.tts.models import setup_model as setup_tts_model from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager @@ -133,21 +133,23 @@ class Synthesizer(object): def _is_use_speaker_embedding(self): """Check if the speaker embedding is used in the model""" - # some models use model_args some don't + # we handle here the case that some models use model_args some don't + use_speaker_embedding = False if hasattr(self.tts_config, "model_args"): - config = self.tts_config.model_args - else: - config = self.tts_config - return hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True + use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False) + use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False) + return use_speaker_embedding def _is_use_d_vector_file(self): """Check if the d-vector file is used in the model""" - # some models use model_args some don't + # we handle here the case that some models use model_args some don't + use_d_vector_file = False if hasattr(self.tts_config, "model_args"): config = self.tts_config.model_args - else: - config = self.tts_config - return hasattr(config, "use_d_vector_file") and config.use_d_vector_file is True + use_d_vector_file = config.get("use_d_vector_file", False) + config = self.tts_config + use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False) + return use_d_vector_file def _init_speaker_manager(self): """Initialize the SpeakerManager""" @@ -176,10 +178,7 @@ class Synthesizer(object): """Initialize the LanguageManager""" # setup if multi-lingual settings are in the global model config language_manager = None - if ( - hasattr(self.tts_config.model_args, "use_language_embedding") - and self.tts_config.model_args.use_language_embedding is True - ): + if check_config_and_model_args(self.tts_config, "use_language_embedding", True): if self.tts_languages_file: language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file) elif self.tts_config.get("language_ids_file", None):