mirror of https://github.com/coqui-ai/TTS.git
Make style
This commit is contained in:
parent
3b63d713b9
commit
04202da1ac
|
@ -291,7 +291,7 @@ class BaseTrainingConfig(Coqpit):
|
||||||
log_model_step (int):
|
log_model_step (int):
|
||||||
Number of steps required to log a checkpoint as W&B artifact
|
Number of steps required to log a checkpoint as W&B artifact
|
||||||
|
|
||||||
save_step (int):ipt
|
save_step (int):
|
||||||
Number of steps required to save the next checkpoint.
|
Number of steps required to save the next checkpoint.
|
||||||
|
|
||||||
checkpoint (bool):
|
checkpoint (bool):
|
||||||
|
|
|
@ -159,4 +159,3 @@ class BaseModel(nn.Module, ABC):
|
||||||
|
|
||||||
def format_batch(self):
|
def format_batch(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -522,14 +522,17 @@ class GlowTTS(BaseTTS):
|
||||||
# init characters
|
# init characters
|
||||||
if config.use_phonemes:
|
if config.use_phonemes:
|
||||||
from TTS.tts.utils.text.characters import IPAPhonemes
|
from TTS.tts.utils.text.characters import IPAPhonemes
|
||||||
|
|
||||||
characters = IPAPhonemes().init_from_config(config)
|
characters = IPAPhonemes().init_from_config(config)
|
||||||
else:
|
else:
|
||||||
from TTS.tts.utils.text.characters import Graphemes
|
from TTS.tts.utils.text.characters import Graphemes
|
||||||
|
|
||||||
characters = Graphemes().init_from_config(config)
|
characters = Graphemes().init_from_config(config)
|
||||||
config.num_chars = characters.num_chars
|
config.num_chars = characters.num_chars
|
||||||
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
ap = AudioProcessor.init_from_config(config)
|
ap = AudioProcessor.init_from_config(config)
|
||||||
tokenizer = TTSTokenizer.init_from_config(config)
|
tokenizer = TTSTokenizer.init_from_config(config)
|
||||||
speaker_manager = SpeakerManager.init_from_config(config)
|
speaker_manager = SpeakerManager.init_from_config(config)
|
||||||
return GlowTTS(config, ap, tokenizer, speaker_manager)
|
return GlowTTS(config, ap, tokenizer, speaker_manager)
|
||||||
|
|
|
@ -42,7 +42,7 @@ class TTSTokenizer:
|
||||||
add_blank: bool = False,
|
add_blank: bool = False,
|
||||||
use_eos_bos=False,
|
use_eos_bos=False,
|
||||||
):
|
):
|
||||||
self.text_cleaner = text_cleaner or (lambda x: x)
|
self.text_cleaner = text_cleaner
|
||||||
self.use_phonemes = use_phonemes
|
self.use_phonemes = use_phonemes
|
||||||
self.add_blank = add_blank
|
self.add_blank = add_blank
|
||||||
self.use_eos_bos = use_eos_bos
|
self.use_eos_bos = use_eos_bos
|
||||||
|
@ -50,6 +50,16 @@ class TTSTokenizer:
|
||||||
self.not_found_characters = []
|
self.not_found_characters = []
|
||||||
self.phonemizer = phonemizer
|
self.phonemizer = phonemizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def characters(self):
|
||||||
|
return self._characters
|
||||||
|
|
||||||
|
@characters.setter
|
||||||
|
def characters(self, new_characters):
|
||||||
|
self._characters = new_characters
|
||||||
|
self.pad_id = self.characters.char_to_id(self.characters.pad)
|
||||||
|
self.blank_id = self.characters.char_to_id(self.characters.blank)
|
||||||
|
|
||||||
def encode(self, text: str) -> List[int]:
|
def encode(self, text: str) -> List[int]:
|
||||||
"""Encodes a string of text as a sequence of IDs."""
|
"""Encodes a string of text as a sequence of IDs."""
|
||||||
token_ids = []
|
token_ids = []
|
||||||
|
@ -61,6 +71,7 @@ class TTSTokenizer:
|
||||||
# discard but store not found characters
|
# discard but store not found characters
|
||||||
if char not in self.not_found_characters:
|
if char not in self.not_found_characters:
|
||||||
self.not_found_characters.append(char)
|
self.not_found_characters.append(char)
|
||||||
|
print(text)
|
||||||
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
|
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
|
@ -88,7 +99,8 @@ class TTSTokenizer:
|
||||||
5. Text to token IDs
|
5. Text to token IDs
|
||||||
"""
|
"""
|
||||||
# TODO: text cleaner should pick the right routine based on the language
|
# TODO: text cleaner should pick the right routine based on the language
|
||||||
text = self.text_cleaner(text)
|
if self.text_cleaner is not None:
|
||||||
|
text = self.text_cleaner(text)
|
||||||
if self.use_phonemes:
|
if self.use_phonemes:
|
||||||
text = self.phonemizer.phonemize(text, separator="")
|
text = self.phonemizer.phonemize(text, separator="")
|
||||||
if self.add_blank:
|
if self.add_blank:
|
||||||
|
@ -144,7 +156,9 @@ class TTSTokenizer:
|
||||||
if "phonemizer" in config and config.phonemizer:
|
if "phonemizer" in config and config.phonemizer:
|
||||||
phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
|
phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
|
||||||
else:
|
else:
|
||||||
phonemizer = get_phonemizer_by_name(DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs)
|
phonemizer = get_phonemizer_by_name(
|
||||||
|
DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# init character set
|
# init character set
|
||||||
characters = Graphemes().init_from_config(config)
|
characters = Graphemes().init_from_config(config)
|
||||||
|
|
|
@ -56,10 +56,10 @@ class TestTTSTokenizer(unittest.TestCase):
|
||||||
self.ph = ESpeak("en-us")
|
self.ph = ESpeak("en-us")
|
||||||
self.tokenizer_local = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph)
|
self.tokenizer_local = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph)
|
||||||
self.assertEqual(len(self.tokenizer.not_found_characters), 0)
|
self.assertEqual(len(self.tokenizer.not_found_characters), 0)
|
||||||
text = "Yolk of one egg beaten light"
|
text = "Yolk of one egg beaten light"
|
||||||
ids = self.tokenizer_local.text_to_ids(text)
|
ids = self.tokenizer_local.text_to_ids(text)
|
||||||
text_hat = self.tokenizer_local.ids_to_text(ids)
|
text_hat = self.tokenizer_local.ids_to_text(ids)
|
||||||
self.assertEqual(self.tokenizer_local.not_found_characters, ['̩'])
|
self.assertEqual(self.tokenizer_local.not_found_characters, ["̩"])
|
||||||
self.assertEqual(text_hat, "jˈoʊk ʌv wˈʌn ˈɛɡ bˈiːʔn lˈaɪt")
|
self.assertEqual(text_hat, "jˈoʊk ʌv wˈʌn ˈɛɡ bˈiːʔn lˈaɪt")
|
||||||
|
|
||||||
def test_init_from_config(self):
|
def test_init_from_config(self):
|
||||||
|
|
Loading…
Reference in New Issue