mirror of https://github.com/coqui-ai/TTS.git
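"""Character-level tokenizer for CTC-style models: builds a character
vocabulary from transcripts, maps text to id sequences, and collapses
repeated tokens when turning model output back into a string."""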
import re
from itertools import groupby
from typing import Dict, List, Optional


class Tokenizer:
    # raw string avoids invalid-escape warnings; "\n" matches newlines
    _CHARS_TO_IGNORE = r"[\,\?\.\!\-\;\:\[\]\(\)'\`\"\“\”\’\n]"

    def __init__(
        self,
        transcripts: Optional[List[str]] = None,
        vocab_dict: Optional[Dict] = None,
        chars_to_ignore: Optional[str] = None,
    ):
        super().__init__()

        self.chars_to_ignore = chars_to_ignore if chars_to_ignore else self._CHARS_TO_IGNORE
        if transcripts:
            _, vocab_dict = self.extract_vocab(transcripts)
        assert vocab_dict is not None, "either `transcripts` or `vocab_dict` must be given"
        self.encode_dict = vocab_dict
        self.decode_dict = {v: k for k, v in vocab_dict.items()}
        assert len(self.encode_dict) == len(self.decode_dict)

        self.unk_token = "[UNK]"
        self.pad_token = "[PAD]"
        self.word_del_token = "|"

        # make sure the special tokens are present in both mappings
        if self.unk_token not in self.encode_dict:
            self.encode_dict[self.unk_token] = len(self.encode_dict)
            self.decode_dict[len(self.decode_dict)] = self.unk_token

        if self.pad_token not in self.encode_dict:
            self.encode_dict[self.pad_token] = len(self.encode_dict)
            self.decode_dict[len(self.decode_dict)] = self.pad_token

        self.unk_token_id = self.encode_dict[self.unk_token]
        self.pad_token_id = self.encode_dict[self.pad_token]

    @property
    def vocab_size(self):
        return len(self.encode_dict)

    @property
    def vocab_dict(self):
        return self.encode_dict

    def encode(self, text: str):
        """Convert input text to a sequence of ids."""
        text = text.lower()
        text = self.remove_special_characters(text)
        text = text.replace(" ", self.word_del_token)
        # fall back to the id of [UNK], not the token string itself
        ids = [self.encode_dict.get(char, self.unk_token_id) for char in text]
        return ids

    def decode(self, token_ids: List[int]):
        """Convert a sequence of ids to a list of tokens."""
        tokens = [self.decode_dict.get(ti, self.unk_token) for ti in token_ids]
        return tokens

    def tokens_to_string(self, tokens, group_tokens=True):
        # group repeated tokens into a single token, as in CTC decoding
        if group_tokens:
            tokens = [token_group[0] for token_group in groupby(tokens)]

        # filter out self.pad_token, which is used as the CTC blank token
        filtered_tokens = list(filter(lambda token: token != self.pad_token, tokens))

        # replace the word delimiter token with a space
        string = "".join([" " if token == self.word_del_token else token for token in filtered_tokens]).strip()
        return string

    def remove_special_characters(self, text):
        text = re.sub(self.chars_to_ignore, "", text).lower()
        return text

    def extract_vocab(self, texts: List[str]):
        vocab = set()
        for text in texts:
            text = text.lower()
            text = self.remove_special_characters(text)
            vocab.update(text)
        # sort so that id assignment is deterministic across runs
        vocab_list = sorted(vocab)
        vocab_dict = {v: k for k, v in enumerate(vocab_list)}
        # represent word boundaries with "|" instead of a literal space
        vocab_dict["|"] = vocab_dict[" "]
        del vocab_dict[" "]
        vocab_dict["[UNK]"] = len(vocab_dict)
        vocab_dict["[PAD]"] = len(vocab_dict)
        return vocab_list, vocab_dict
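

# A minimal usage sketch (an illustration, not part of the original file):
# the transcripts below are made-up examples. "hello world" round-trips
# through encode/decode; group_tokens=False skips the CTC collapse, which
# would otherwise merge the repeated "l" in "hello".
if __name__ == "__main__":
    tokenizer = Tokenizer(transcripts=["hello world", "the quick brown fox"])
    ids = tokenizer.encode("hello world")
    tokens = tokenizer.decode(ids)
    print(tokenizer.tokens_to_string(tokens, group_tokens=False))  # hello world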