From bb389479a46b5843ce765381a0ddff3078cacbb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 18:46:18 +0100 Subject: [PATCH] Update setup_model for TTS.tts models --- TTS/tts/models/__init__.py | 44 ++------------------------------------ 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index c8371106..cb1c2e21 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,52 +1,12 @@ -from TTS.tts.utils.text.characters import make_symbols, parse_symbols from TTS.utils.generic_utils import find_module -def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None): +def setup_model(config: "Coqpit") -> "BaseTTS": print(" > Using model: {}".format(config.model)) # fetch the right model implementation. if "base_model" in config and config["base_model"] is not None: MyModel = find_module("TTS.tts.models", config.base_model.lower()) else: MyModel = find_module("TTS.tts.models", config.model.lower()) - # define set of characters used by the model - if config.characters is not None: - # set characters from config - if hasattr(MyModel, "make_symbols"): - symbols = MyModel.make_symbols(config) - else: - symbols, phonemes = make_symbols(**config.characters) - else: - from TTS.tts.utils.text.characters import phonemes, symbols # pylint: disable=import-outside-toplevel - - if config.use_phonemes: - symbols = phonemes - # use default characters and assign them to config - config.characters = parse_symbols() - # consider special `blank` character if `add_blank` is set True - num_chars = len(symbols) + getattr(config, "add_blank", False) - config.num_chars = num_chars - # compatibility fix - if "model_params" in config: - config.model_params.num_chars = num_chars - if "model_args" in config: - config.model_args.num_chars = num_chars - if config.model.lower() in ["vits"]: # If model supports multiple languages - model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager) - else: - model = MyModel(config, speaker_manager=speaker_manager) + model = MyModel.init_from_config(config) return model - - -# TODO; class registery -# def import_models(models_dir, namespace): -# for file in os.listdir(models_dir): -# path = os.path.join(models_dir, file) -# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): -# model_name = file[: file.find(".py")] if file.endswith(".py") else file -# importlib.import_module(namespace + "." + model_name) -# -# -## automatically import any Python files in the models/ directory -# models_dir = os.path.dirname(__file__) -# import_models(models_dir, "TTS.tts.models")