mirror of https://github.com/coqui-ai/TTS.git
Load right char class dynamically
This commit is contained in:
parent
ec4b03c045
commit
d5c0e17548
|
@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Union
|
|||
from TTS.tts.utils.text import cleaners
|
||||
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
|
||||
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
|
||||
from TTS.utils.generic_utils import get_import_path, import_class
|
||||
|
||||
|
||||
class TTSTokenizer:
|
||||
|
@ -152,15 +153,25 @@ class TTSTokenizer:
|
|||
|
||||
# init characters
|
||||
if characters is None:
|
||||
if config.use_phonemes:
|
||||
# init phoneme set
|
||||
characters, new_config = IPAPhonemes().init_from_config(config)
|
||||
# set characters based on defined characters class
|
||||
if config.characters and config.characters.characters_class:
|
||||
CharactersClass = import_class(config.characters.characters_class)
|
||||
characters, new_config = CharactersClass.init_from_config(config)
|
||||
# set characters based on config
|
||||
else:
|
||||
# init character set
|
||||
characters, new_config = Graphemes().init_from_config(config)
|
||||
if config.use_phonemes:
|
||||
# init phoneme set
|
||||
characters, new_config = IPAPhonemes().init_from_config(config)
|
||||
else:
|
||||
# init character set
|
||||
characters, new_config = Graphemes().init_from_config(config)
|
||||
|
||||
else:
|
||||
characters, new_config = characters.init_from_config(config)
|
||||
|
||||
# set characters class
|
||||
new_config.characters.characters_class = get_import_path(characters)
|
||||
|
||||
# init phonemizer
|
||||
phonemizer = None
|
||||
if config.use_phonemes:
|
||||
|
|
|
@ -95,6 +95,33 @@ def find_module(module_path: str, module_name: str) -> object:
|
|||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def import_class(module_path: str) -> object:
|
||||
"""Import a class from a module path.
|
||||
|
||||
Args:
|
||||
module_path (str): The module path of the class.
|
||||
|
||||
Returns:
|
||||
object: The imported class.
|
||||
"""
|
||||
class_name = module_path.split(".")[-1]
|
||||
module_path = ".".join(module_path.split(".")[:-1])
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def get_import_path(obj: object) -> str:
|
||||
"""Get the import path of a class.
|
||||
|
||||
Args:
|
||||
obj (object): The class object.
|
||||
|
||||
Returns:
|
||||
str: The import path of the class.
|
||||
"""
|
||||
return ".".join([type(obj).__module__, type(obj).__name__])
|
||||
|
||||
|
||||
def get_user_data_dir(appname):
|
||||
if sys.platform == "win32":
|
||||
import winreg # pylint: disable=import-outside-toplevel
|
||||
|
|
Loading…
Reference in New Issue