mirror of https://github.com/coqui-ai/TTS.git
Implement BasePhonemizer
This commit is contained in:
parent
dcd01356e0
commit
c1119bc291
|
@ -0,0 +1,51 @@
|
|||
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
|
||||
from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
|
||||
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
|
||||
from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
|
||||
from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer
|
||||
|
||||
PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, JA_JP_Phonemizer)}
|
||||
|
||||
|
||||
ESPEAK_LANGS = list(ESpeak.supported_languages().keys())
|
||||
GRUUT_LANGS = list(Gruut.supported_languages())
|
||||
|
||||
|
||||
# Dict setting default phonemizers for each language
|
||||
DEF_LANG_TO_PHONEMIZER = {
|
||||
"ja-jp": JA_JP_Phonemizer.name(),
|
||||
"zh-cn": ZH_CN_Phonemizer.name(),
|
||||
}
|
||||
|
||||
|
||||
# Add Gruut languages
|
||||
_ = [Gruut.name()] * len(GRUUT_LANGS)
|
||||
_new_dict = dict(list(zip(GRUUT_LANGS, _)))
|
||||
DEF_LANG_TO_PHONEMIZER.update(_new_dict)
|
||||
|
||||
|
||||
# Add ESpeak languages and override any existing ones
|
||||
_ = [ESpeak.name()] * len(ESPEAK_LANGS)
|
||||
_new_dict = dict(list(zip(list(ESPEAK_LANGS), _)))
|
||||
DEF_LANG_TO_PHONEMIZER.update(_new_dict)
|
||||
|
||||
|
||||
def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer:
|
||||
"""Initiate a phonemizer by name
|
||||
|
||||
Args:
|
||||
name (str):
|
||||
Name of the phonemizer that should match `phonemizer.name()`.
|
||||
|
||||
kwargs (dict):
|
||||
Extra keyword arguments that should be passed to the phonemizer.
|
||||
"""
|
||||
if name == "espeak":
|
||||
return ESpeak(**kwargs)
|
||||
if name == "gruut":
|
||||
return Gruut(**kwargs)
|
||||
if name == "zh_cn_phonemizer":
|
||||
return ZH_CN_Phonemizer(**kwargs)
|
||||
if name == "ja_jp_phonemizer":
|
||||
return JA_JP_Phonemizer(**kwargs)
|
||||
raise ValueError(f"Phonemizer {name} not found")
|
|
@ -0,0 +1,136 @@
|
|||
import abc
|
||||
import itertools
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from TTS.tts.utils.text.punctuation import Punctuation
|
||||
|
||||
|
||||
class BasePhonemizer(abc.ABC):
|
||||
"""Base phonemizer class
|
||||
|
||||
Args:
|
||||
language (str):
|
||||
Language used by the phonemizer.
|
||||
|
||||
punctuations (List[str]):
|
||||
List of punctuation marks to be preserved.
|
||||
|
||||
keep_puncs (bool):
|
||||
Whether to preserve punctuation marks or not.
|
||||
"""
|
||||
|
||||
def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
|
||||
|
||||
# ensure the backend is installed on the system
|
||||
if not self.is_available():
|
||||
raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
|
||||
|
||||
# ensure the backend support the requested language
|
||||
self._language = self._init_language(language)
|
||||
|
||||
# setup punctuation processing
|
||||
self._keep_puncs = keep_puncs
|
||||
self._punctuator = Punctuation(punctuations)
|
||||
|
||||
|
||||
def _init_language(self, language):
|
||||
"""Language initialization
|
||||
|
||||
This method may be overloaded in child classes (see Segments backend)
|
||||
|
||||
"""
|
||||
if not self.is_supported_language(language):
|
||||
raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
|
||||
return language
|
||||
|
||||
@property
|
||||
def language(self):
|
||||
"""The language code configured to be used for phonemization"""
|
||||
return self._language
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def name():
|
||||
"""The name of the backend"""
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def is_available(cls):
|
||||
"""Returns True if the backend is installed, False otherwise"""
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def version(cls):
|
||||
"""Return the backend version as a tuple (major, minor, patch)"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def supported_languages():
|
||||
"""Return a dict of language codes -> name supported by the backend"""
|
||||
|
||||
def is_supported_language(self, language):
|
||||
"""Returns True if `language` is supported by the backend"""
|
||||
return language in self.supported_languages()
|
||||
|
||||
fr"""
|
||||
Phonemization follows the following steps:
|
||||
1. Preprocessing:
|
||||
- remove empty lines
|
||||
- remove punctuation
|
||||
- keep track of punctuation marks
|
||||
|
||||
2. Phonemization:
|
||||
- convert text to phonemes
|
||||
|
||||
3. Postprocessing:
|
||||
- join phonemes
|
||||
- restore punctuation marks
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def _phonemize(self, text, separator):
|
||||
"""The main phonemization method"""
|
||||
|
||||
def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
|
||||
"""Preprocess the text before phonemization
|
||||
|
||||
Override this if you need a different behaviour
|
||||
"""
|
||||
if self._keep_puncs:
|
||||
# a tuple (text, punctuation marks)
|
||||
return self._punctuator.strip_to_restore(text)
|
||||
return [self._punctuator.strip(text)], []
|
||||
|
||||
def _phonemize_postprocess(self, phonemized, punctuations) -> str:
|
||||
"""Postprocess the raw phonemized output
|
||||
|
||||
Override this if you need a different behaviour
|
||||
"""
|
||||
if self._keep_puncs:
|
||||
return self._punctuator.restore(phonemized, punctuations)[0]
|
||||
return phonemized[0]
|
||||
|
||||
def phonemize(self, text: str, separator="|") -> str:
|
||||
"""Returns the `text` phonemized for the given language
|
||||
|
||||
Args:
|
||||
text (str):
|
||||
Text to be phonemized.
|
||||
|
||||
separator (str):
|
||||
string separator used between phonemes. Default to '_'.
|
||||
|
||||
Returns:
|
||||
(str): Phonemized text
|
||||
"""
|
||||
text, punctuations = self._phonemize_preprocess(text)
|
||||
phonemized = []
|
||||
for t in text:
|
||||
p = self._phonemize(t, separator)
|
||||
phonemized.append(p)
|
||||
phonemized = self._phonemize_postprocess(phonemized, punctuations)
|
||||
return phonemized
|
||||
|
||||
def print_logs(self, level: int=0):
|
||||
indent = "\t" * level
|
||||
print(f"{indent}| > phoneme language: {self.language}")
|
||||
print(f"{indent}| > phoneme backend: {self.name()}")
|
Loading…
Reference in New Issue