diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index c6390beb..8c1bd430 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -4,20 +4,23 @@ from TTS.utils.generic_utils import find_module def setup_model(config): print(" > Using model: {}".format(config.model)) - MyModel = find_module("TTS.tts.models", config.model.lower()) # define set of characters used by the model if config.characters is not None: # set characters from config - symbols, phonemes = make_symbols(**config.characters.to_dict()) # pylint: disable=redefined-outer-name + if hasattr(MyModel, "make_symbols"): + symbols = MyModel.make_symbols(config) + else: + symbols, phonemes = make_symbols(**config.characters) else: from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel + if config.use_phonemes: + symbols = phonemes # use default characters and assign them to config config.characters = parse_symbols() - num_chars = len(phonemes) if config.use_phonemes else len(symbols) # consider special `blank` character if `add_blank` is set True - num_chars = num_chars + getattr(config, "add_blank", False) + num_chars = len(symbols) + getattr(config, "add_blank", False) config.num_chars = num_chars # compatibility fix if "model_params" in config: