diff --git a/TTS/stt/datasets/tokenizer.py b/TTS/stt/datasets/tokenizer.py
new file mode 100644
index 00000000..94e8d4f4
--- /dev/null
+++ b/TTS/stt/datasets/tokenizer.py
@@ -0,0 +1,85 @@
+import re
+from itertools import groupby
+from typing import Dict, List
+
+
+class Tokenizer:
+    _CHARS_TO_IGNORE = r"[\,\?\.\!\-\;\:\[\]\(\)'\`\"\“\”\’\n]"
+
+    def __init__(self, transcripts: List[str] = None, vocab_dict: Dict = None, chars_to_ignore: str = None):
+        super().__init__()
+
+        self.chars_to_ignore = chars_to_ignore if chars_to_ignore else self._CHARS_TO_IGNORE
+        if transcripts:
+            _, vocab_dict = self.extract_vocab(transcripts)
+        self.encode_dict = vocab_dict
+        self.decode_dict = {v: k for k, v in vocab_dict.items()}
+        assert len(self.encode_dict) == len(self.decode_dict)
+
+        self.unk_token = "[UNK]"
+        self.pad_token = "[PAD]"
+        self.word_del_token = "|"
+
+        if self.unk_token not in self.encode_dict:
+            self.encode_dict[self.unk_token] = len(self.encode_dict)
+            self.decode_dict[len(self.decode_dict)] = self.unk_token
+
+        if self.pad_token not in self.encode_dict:
+            self.encode_dict[self.pad_token] = len(self.encode_dict)
+            self.decode_dict[len(self.decode_dict)] = self.pad_token
+
+        self.unk_token_id = self.encode_dict[self.unk_token]
+        self.pad_token_id = self.encode_dict[self.pad_token]
+
+    @property
+    def vocab_size(self):
+        return len(self.encode_dict)
+
+    @property
+    def vocab_dict(self):
+        return self.encode_dict
+
+    def encode(self, text: str):
+        """Convert input text to a sequence of token ids."""
+        text = text.lower()
+        text = self.remove_special_characters(text)
+        text = text.replace(" ", self.word_del_token)
+        ids = [self.encode_dict.get(char, self.unk_token_id) for char in text]
+        return ids
+
+    def decode(self, token_ids: List[int]):
+        """Convert a sequence of token ids to a list of tokens."""
+        tokens = [self.decode_dict.get(ti, self.unk_token) for ti in token_ids]
+        return tokens
+
+    def tokens_to_string(self, tokens, group_tokens=True):
+        # collapse repeated tokens into a single token (CTC-style decoding)
+        if group_tokens:
+            tokens = [token_group[0] for token_group in groupby(tokens)]
+
+        # drop self.pad_token, which is used as the CTC blank token
+        filtered_tokens = list(filter(lambda token: token != self.pad_token, tokens))
+
+        # replace the word delimiter token with a space
+        string = "".join([" " if token == self.word_del_token else token for token in filtered_tokens]).strip()
+        return string
+
+    def remove_special_characters(self, text):
+        text = re.sub(self.chars_to_ignore, "", text).lower()
+        return text
+
+    def extract_vocab(self, texts: List[str]):
+        vocab = set()
+        # collect every character that appears in the normalized transcripts
+        for text in texts:
+            text = text.lower()
+            text = self.remove_special_characters(text)
+            vocab.update(text)
+        vocab = list(vocab)
+        vocab_list = list(set(vocab))
+        vocab_dict = {v: k for k, v in enumerate(vocab_list)}
+        # map the space character to the word delimiter token
+        vocab_dict["|"] = vocab_dict.pop(" ", len(vocab_dict))
+        vocab_dict["[UNK]"] = len(vocab_dict)
+        vocab_dict["[PAD]"] = len(vocab_dict)
+        return vocab, vocab_dict
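
For reviewers, a minimal usage sketch of the new class, assuming the module is importable as `TTS.stt.datasets.tokenizer` (the surrounding package `__init__` files are not part of this diff); the transcripts and strings below are purely illustrative:

```python
from TTS.stt.datasets.tokenizer import Tokenizer

# build a character-level vocabulary from a couple of sample transcripts (illustrative only)
transcripts = ["Hello world!", "CTC needs a character vocabulary."]
tokenizer = Tokenizer(transcripts=transcripts)

# round trip: text -> ids -> tokens -> text
ids = tokenizer.encode("hello world")   # characters not in the vocab map to tokenizer.unk_token_id
tokens = tokenizer.decode(ids)          # character tokens; "|" marks word boundaries
print(tokenizer.tokens_to_string(tokens, group_tokens=False))  # -> "hello world"

# CTC-style collapse: repeated tokens are merged and "[PAD]" acts as the blank,
# so a blank between the two "l"s keeps them from being merged into one
ctc_output = ["h", "h", "e", "l", "[PAD]", "l", "o", "o"]
print(tokenizer.tokens_to_string(ctc_output))  # -> "hello"
```

Note that the default `group_tokens=True` applies the CTC merge rule, so a plain round trip of text containing doubled letters (the "ll" in "hello") should pass `group_tokens=False`.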