From 690de1ab06c4709acfade4a399cbfa4299a29784 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 23 Feb 2022 12:54:36 +0100
Subject: [PATCH] Update Characters and add more tests

---
Reviewer notes (this section is ignored by `git am`):
- Line breaks and indentation were reconstructed from the hunk headers after
  the patch text had been collapsed onto a few lines; verify whitespace against
  the target tree before applying.
- test_char_to_id / test_id_to_char now use assertRaises(KeyError). The earlier
  try / raise Exception / bare-except pattern swallowed its own failure signal,
  so those checks could never fail.
- Fixed "an token ID" -> "a token ID" in two added docstrings.

 TTS/tts/utils/text/characters.py    | 57 ++++++-----------------
 tests/text_tests/test_characters.py | 40 +++++++++++++++++++-
 2 files changed, 50 insertions(+), 47 deletions(-)

diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py
index 0ce65a90..1b375e4f 100644
--- a/TTS/tts/utils/text/characters.py
+++ b/TTS/tts/utils/text/characters.py
@@ -35,51 +35,6 @@ _diacrilics = "ɚ˞ɫ"
 _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
 
 
-# def create_graphemes(
-#     characters=_characters,
-#     punctuations=_punctuations,
-#     pad=_pad,
-#     eos=_eos,
-#     bos=_bos,
-#     blank=_blank,
-#     unique=True,
-# ):  # pylint: disable=redefined-outer-name
-#     """Function to create default characters and phonemes"""
-#     # create graphemes
-#     = (
-#         sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
-#     )  # this is to keep previous models compatible.
-#     _graphemes = list(characters)
-#     _graphemes = [bos] + _graphemes if len(bos) > 0 and bos is not None else _graphemes
-#     _graphemes = [eos] + _graphemes if len(bos) > 0 and eos is not None else _graphemes
-#     _graphemes = [pad] + _graphemes if len(bos) > 0 and pad is not None else _graphemes
-#     _graphemes = [blank] + _graphemes if len(bos) > 0 and blank is not None else _graphemes
-#     _graphemes = _graphemes + list(punctuations)
-#     return _graphemes, _phonemes
-
-
-# def create_phonemes(
-#     phonemes=_phonemes, punctuations=_punctuations, pad=_pad, eos=_eos, bos=_bos, blank=_blank, unique=True
-# ):
-#     # create phonemes
-#     _phonemes = None
-#     _phonemes_sorted = (
-#         sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
-#     )  # this is to keep previous models compatible.
-#     _phonemes = list(_phonemes_sorted)
-#     _phonemes = [bos] + _phonemes if len(bos) > 0 and bos is not None else _phonemes
-#     _phonemes = [eos] + _phonemes if len(bos) > 0 and eos is not None else _phonemes
-#     _phonemes = [pad] + _phonemes if len(bos) > 0 and pad is not None else _phonemes
-#     _phonemes = [blank] + _phonemes if len(bos) > 0 and blank is not None else _phonemes
-#     _phonemes = _phonemes + list(punctuations)
-#     _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
-#     return _phonemes
-
-
-# DEF_GRAPHEMES = create_graphemes(_characters, _phonemes, _punctuations, _pad, _eos, _bos)
-# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
-
-
 class BaseVocabulary:
     """Base Vocabulary class.
 
@@ -98,18 +53,24 @@ class BaseVocabulary:
 
     @property
     def pad_id(self) -> int:
+        """Return the index of the padding character. If the padding character is not specified, return the length
+        of the vocabulary."""
         return self.char_to_id(self.pad) if self.pad else len(self.vocab)
 
     @property
     def blank_id(self) -> int:
+        """Return the index of the blank character. If the blank character is not specified, return the length of
+        the vocabulary."""
         return self.char_to_id(self.blank) if self.blank else len(self.vocab)
 
     @property
     def vocab(self):
+        """Return the vocabulary dictionary."""
         return self._vocab
 
     @vocab.setter
     def vocab(self, vocab):
+        """Set the vocabulary dictionary and character mapping dictionaries."""
         self._vocab = vocab
         self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
         self._id_to_char = {
@@ -118,6 +79,7 @@ class BaseVocabulary:
 
     @staticmethod
     def init_from_config(config, **kwargs):
+        """Initialize from the given config."""
         if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
             return (
                 BaseVocabulary(
@@ -133,15 +95,18 @@ class BaseVocabulary:
 
     @property
     def num_chars(self):
-        return max(self._vocab.values()) + 1
+        """Return number of tokens in the vocabulary."""
+        return len(self._vocab)
 
     def char_to_id(self, char: str) -> int:
+        """Map a character to a token ID."""
         try:
             return self._char_to_id[char]
         except KeyError as e:
             raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
 
     def id_to_char(self, idx: int) -> str:
+        """Map a token ID to a character."""
         return self._id_to_char[idx]
 
 
diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py
index 5432c652..2ebb3bc3 100644
--- a/tests/text_tests/test_characters.py
+++ b/tests/text_tests/test_characters.py
@@ -1,9 +1,47 @@
 import unittest
 
-from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes
+from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, BaseVocabulary
 
 # pylint: disable=protected-access
 
+class BaseVocabularyTest(unittest.TestCase):
+    def setUp(self):
+        self.phonemes = IPAPhonemes()
+        self.base_vocab = BaseVocabulary(vocab=self.phonemes._vocab, pad=self.phonemes.pad, blank=self.phonemes.blank, bos=self.phonemes.bos, eos=self.phonemes.eos)
+        self.empty_vocab = BaseVocabulary({})
+
+    def test_pad_id(self):
+        self.assertEqual(self.empty_vocab.pad_id, 0)
+        self.assertEqual(self.base_vocab.pad_id, self.phonemes.pad_id)
+
+    def test_blank_id(self):
+        self.assertEqual(self.empty_vocab.blank_id, 0)
+        self.assertEqual(self.base_vocab.blank_id, self.phonemes.blank_id)
+
+    def test_vocab(self):
+        self.assertEqual(self.empty_vocab.vocab, {})
+        self.assertEqual(self.base_vocab.vocab, self.phonemes._vocab)
+
+    def test_init_from_config(self):
+        ...
+
+    def test_num_chars(self):
+        self.assertEqual(self.empty_vocab.num_chars, 0)
+        self.assertEqual(self.base_vocab.num_chars, self.phonemes.num_chars)
+
+    def test_char_to_id(self):
+        with self.assertRaises(KeyError):
+            self.empty_vocab.char_to_id("a")
+        for k in self.phonemes.vocab:
+            self.assertEqual(self.base_vocab.char_to_id(k), self.phonemes.char_to_id(k))
+
+    def test_id_to_char(self):
+        with self.assertRaises(KeyError):
+            self.empty_vocab.id_to_char(0)
+        for k in self.phonemes.vocab:
+            v = self.phonemes.char_to_id(k)
+            self.assertEqual(self.base_vocab.id_to_char(v), self.phonemes.id_to_char(v))
+
 
 class BaseCharacterTest(unittest.TestCase):
     def setUp(self):