From d5c0e17548e0a846dc0fb77653965274ae58d880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 28 Jan 2022 10:20:07 +0100 Subject: [PATCH] Load right char class dynamically --- TTS/tts/utils/text/tokenizer.py | 21 ++++++++++++++++----- TTS/utils/generic_utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 80be368d..bdaf8ea6 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Union from TTS.tts.utils.text import cleaners from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +from TTS.utils.generic_utils import get_import_path, import_class class TTSTokenizer: @@ -152,15 +153,25 @@ class TTSTokenizer: # init characters if characters is None: - if config.use_phonemes: - # init phoneme set - characters, new_config = IPAPhonemes().init_from_config(config) + # set characters based on defined characters class + if config.characters and config.characters.characters_class: + CharactersClass = import_class(config.characters.characters_class) + characters, new_config = CharactersClass.init_from_config(config) + # set characters based on config else: - # init character set - characters, new_config = Graphemes().init_from_config(config) + if config.use_phonemes: + # init phoneme set + characters, new_config = IPAPhonemes().init_from_config(config) + else: + # init character set + characters, new_config = Graphemes().init_from_config(config) + else: characters, new_config = characters.init_from_config(config) + # set characters class + new_config.characters.characters_class = get_import_path(characters) + # init phonemizer phonemizer = None if config.use_phonemes: diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 6504cca6..69609bcb 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -95,6 +95,33 @@ def find_module(module_path: str, module_name: str) -> object: return getattr(module, class_name) +def import_class(module_path: str) -> object: + """Import a class from a module path. + + Args: + module_path (str): The module path of the class. + + Returns: + object: The imported class. + """ + class_name = module_path.split(".")[-1] + module_path = ".".join(module_path.split(".")[:-1]) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_import_path(obj: object) -> str: + """Get the import path of a class. + + Args: + obj (object): The class object. + + Returns: + str: The import path of the class. + """ + return ".".join([type(obj).__module__, type(obj).__name__]) + + def get_user_data_dir(appname): if sys.platform == "win32": import winreg # pylint: disable=import-outside-toplevel