Load right char class dynamically

This commit is contained in:
Eren Gölge 2022-01-28 10:20:07 +01:00
parent ec4b03c045
commit d5c0e17548
2 changed files with 43 additions and 5 deletions

View File

@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.utils.generic_utils import get_import_path, import_class
class TTSTokenizer:
@ -152,15 +153,25 @@ class TTSTokenizer:
# init characters
if characters is None:
if config.use_phonemes:
# init phoneme set
characters, new_config = IPAPhonemes().init_from_config(config)
# set characters based on defined characters class
if config.characters and config.characters.characters_class:
CharactersClass = import_class(config.characters.characters_class)
characters, new_config = CharactersClass.init_from_config(config)
# set characters based on config
else:
# init character set
characters, new_config = Graphemes().init_from_config(config)
if config.use_phonemes:
# init phoneme set
characters, new_config = IPAPhonemes().init_from_config(config)
else:
# init character set
characters, new_config = Graphemes().init_from_config(config)
else:
characters, new_config = characters.init_from_config(config)
# set characters class
new_config.characters.characters_class = get_import_path(characters)
# init phonemizer
phonemizer = None
if config.use_phonemes:

View File

@ -95,6 +95,33 @@ def find_module(module_path: str, module_name: str) -> object:
return getattr(module, class_name)
def import_class(module_path: str) -> object:
"""Import a class from a module path.
Args:
module_path (str): The module path of the class.
Returns:
object: The imported class.
"""
class_name = module_path.split(".")[-1]
module_path = ".".join(module_path.split(".")[:-1])
module = importlib.import_module(module_path)
return getattr(module, class_name)
def get_import_path(obj: object) -> str:
"""Get the import path of a class.
Args:
obj (object): The class object.
Returns:
str: The import path of the class.
"""
return ".".join([type(obj).__module__, type(obj).__name__])
def get_user_data_dir(appname):
if sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel