mirror of https://github.com/coqui-ai/TTS.git
Discard OOV chars in tokenizer
Discard but store OOV chars, printing a warning message the first time each OOV char is encountered
This commit is contained in:
parent
c39aaafbfc
commit
0fe39166fe
|
@ -8,6 +8,8 @@ from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemize
|
||||||
class TTSTokenizer:
|
class TTSTokenizer:
|
||||||
"""🐸TTS tokenizer to convert input characters to token IDs and back.
|
"""🐸TTS tokenizer to convert input characters to token IDs and back.
|
||||||
|
|
||||||
|
Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
use_phonemes (bool):
|
use_phonemes (bool):
|
||||||
Whether to use phonemes instead of characters. Defaults to False.
|
Whether to use phonemes instead of characters. Defaults to False.
|
||||||
|
@ -45,14 +47,21 @@ class TTSTokenizer:
|
||||||
self.add_blank = add_blank
|
self.add_blank = add_blank
|
||||||
self.use_eos_bos = use_eos_bos
|
self.use_eos_bos = use_eos_bos
|
||||||
self.characters = characters
|
self.characters = characters
|
||||||
|
self.not_found_characters = []
|
||||||
self.phonemizer = phonemizer
|
self.phonemizer = phonemizer
|
||||||
|
|
||||||
def encode(self, text: str) -> List[int]:
|
def encode(self, text: str) -> List[int]:
|
||||||
"""Encodes a string of text as a sequence of IDs."""
|
"""Encodes a string of text as a sequence of IDs."""
|
||||||
token_ids = []
|
token_ids = []
|
||||||
for char in text:
|
for char in text:
|
||||||
|
try:
|
||||||
idx = self.characters.char_to_id(char)
|
idx = self.characters.char_to_id(char)
|
||||||
token_ids.append(idx)
|
token_ids.append(idx)
|
||||||
|
except KeyError:
|
||||||
|
# discard but store not found characters
|
||||||
|
if char not in self.not_found_characters:
|
||||||
|
self.not_found_characters.append(char)
|
||||||
|
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
def decode(self, token_ids: List[int]) -> str:
|
def decode(self, token_ids: List[int]) -> str:
|
||||||
|
@ -109,6 +118,10 @@ class TTSTokenizer:
|
||||||
print(f"{indent}| > use_phonemes: {self.use_phonemes}")
|
print(f"{indent}| > use_phonemes: {self.use_phonemes}")
|
||||||
if self.use_phonemes:
|
if self.use_phonemes:
|
||||||
print(f"{indent}| > phonemizer: {self.phonemizer.print_logs(level + 1)}")
|
print(f"{indent}| > phonemizer: {self.phonemizer.print_logs(level + 1)}")
|
||||||
|
if len(self.not_found_characters) > 0:
|
||||||
|
print(f"{indent}| > {len(self.not_found_characters)} not found characters:")
|
||||||
|
for char in self.not_found_characters:
|
||||||
|
print(f"{indent}| > {char}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_from_config(config: "Coqpit"):
|
def init_from_config(config: "Coqpit"):
|
||||||
|
|
Loading…
Reference in New Issue