Update Characters and add more tests

This commit is contained in:
Eren Gölge 2022-02-23 12:54:36 +01:00
parent 7de5afc29a
commit 690de1ab06
2 changed files with 56 additions and 47 deletions

View File

@ -35,51 +35,6 @@ _diacrilics = "ɚ˞ɫ"
# Full phoneme inventory: the individual symbol groups joined with ``+`` in a fixed order.
# NOTE(review): presumably every group is a plain string like ``_diacrilics`` — confirm.
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
# def create_graphemes(
# characters=_characters,
# punctuations=_punctuations,
# pad=_pad,
# eos=_eos,
# bos=_bos,
# blank=_blank,
# unique=True,
# ): # pylint: disable=redefined-outer-name
# """Function to create default characters and phonemes"""
# # create graphemes
# = (
# sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
# ) # this is to keep previous models compatible.
# _graphemes = list(characters)
# _graphemes = [bos] + _graphemes if len(bos) > 0 and bos is not None else _graphemes
# _graphemes = [eos] + _graphemes if len(bos) > 0 and eos is not None else _graphemes
# _graphemes = [pad] + _graphemes if len(bos) > 0 and pad is not None else _graphemes
# _graphemes = [blank] + _graphemes if len(bos) > 0 and blank is not None else _graphemes
# _graphemes = _graphemes + list(punctuations)
# return _graphemes, _phonemes
# def create_phonemes(
# phonemes=_phonemes, punctuations=_punctuations, pad=_pad, eos=_eos, bos=_bos, blank=_blank, unique=True
# ):
# # create phonemes
# _phonemes = None
# _phonemes_sorted = (
# sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
# ) # this is to keep previous models compatible.
# _phonemes = list(_phonemes_sorted)
# _phonemes = [bos] + _phonemes if len(bos) > 0 and bos is not None else _phonemes
# _phonemes = [eos] + _phonemes if len(bos) > 0 and eos is not None else _phonemes
# _phonemes = [pad] + _phonemes if len(bos) > 0 and pad is not None else _phonemes
# _phonemes = [blank] + _phonemes if len(bos) > 0 and blank is not None else _phonemes
# _phonemes = _phonemes + list(punctuations)
# _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
# return _phonemes
# DEF_GRAPHEMES = create_graphemes(_characters, _phonemes, _punctuations, _pad, _eos, _bos)
# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
class BaseVocabulary:
"""Base Vocabulary class.
@ -98,18 +53,24 @@ class BaseVocabulary:
@property
def pad_id(self) -> int:
    """Index of the padding character.

    When no padding character is configured (``self.pad`` is falsy), the
    length of the vocabulary is returned instead.
    """
    if not self.pad:
        # No pad symbol: use the first index past the end of the vocabulary.
        return len(self.vocab)
    return self.char_to_id(self.pad)
@property
def blank_id(self) -> int:
    """Index of the blank character.

    When no blank character is configured (``self.blank`` is falsy), the
    length of the vocabulary is returned instead.
    """
    if not self.blank:
        # No blank symbol: use the first index past the end of the vocabulary.
        return len(self.vocab)
    return self.char_to_id(self.blank)
@property
def vocab(self):
    """Return the vocabulary dictionary held in ``self._vocab``."""
    return self._vocab
@vocab.setter
def vocab(self, vocab):
"""Set the vocabulary dictionary and character mapping dictionaries."""
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
self._id_to_char = {
@ -118,6 +79,7 @@ class BaseVocabulary:
@staticmethod
def init_from_config(config, **kwargs):
"""Initialize from the given config."""
if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
return (
BaseVocabulary(
@ -133,15 +95,18 @@ class BaseVocabulary:
@property
def num_chars(self):
    """Return number of tokens in the vocabulary.

    The count is the number of entries in ``self._vocab``, so an empty
    vocabulary yields 0 instead of raising.
    """
    return len(self._vocab)
def char_to_id(self, char: str) -> int:
    """Map a character to a token ID.

    Raises:
        KeyError: if ``char`` is not in the vocabulary; the original lookup
            error is chained as the cause.
    """
    try:
        token_id = self._char_to_id[char]
    except KeyError as e:
        raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
    return token_id
def id_to_char(self, idx: int) -> str:
    """Map a token ID to a character.

    The ID is looked up directly in ``self._id_to_char``; an unknown ID
    propagates the lookup error unchanged.
    """
    return self._id_to_char[idx]

View File

@ -1,9 +1,53 @@
import unittest
from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes
from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, BaseVocabulary
# pylint: disable=protected-access
class BaseVocabularyTest(unittest.TestCase):
    """Check ``BaseVocabulary`` against ``IPAPhonemes`` and against an empty vocabulary.

    ``base_vocab`` is built from the phoneme set's own vocabulary and special
    symbols, so every lookup must agree with ``IPAPhonemes``. ``empty_vocab``
    covers the boundary case of a vocabulary with no entries.
    """

    def setUp(self):
        self.phonemes = IPAPhonemes()
        self.base_vocab = BaseVocabulary(
            vocab=self.phonemes._vocab,
            pad=self.phonemes.pad,
            blank=self.phonemes.blank,
            bos=self.phonemes.bos,
            eos=self.phonemes.eos,
        )
        self.empty_vocab = BaseVocabulary({})

    def test_pad_id(self):
        # With no pad symbol the ID falls back to the vocabulary length (0 here).
        self.assertEqual(self.empty_vocab.pad_id, 0)
        self.assertEqual(self.base_vocab.pad_id, self.phonemes.pad_id)

    def test_blank_id(self):
        # With no blank symbol the ID falls back to the vocabulary length (0 here).
        self.assertEqual(self.empty_vocab.blank_id, 0)
        self.assertEqual(self.base_vocab.blank_id, self.phonemes.blank_id)

    def test_vocab(self):
        self.assertEqual(self.empty_vocab.vocab, {})
        self.assertEqual(self.base_vocab.vocab, self.phonemes._vocab)

    def test_init_from_config(self):
        # TODO: not implemented yet; placeholder kept so the coverage gap stays visible.
        ...

    def test_num_chars(self):
        self.assertEqual(self.empty_vocab.num_chars, 0)
        self.assertEqual(self.base_vocab.num_chars, self.phonemes.num_chars)

    def test_char_to_id(self):
        # assertRaises fails the test when no KeyError is raised; a bare
        # ``except`` around a sentinel ``raise`` would swallow the sentinel too.
        with self.assertRaises(KeyError):
            self.empty_vocab.char_to_id("a")
        for k in self.phonemes.vocab:
            self.assertEqual(self.base_vocab.char_to_id(k), self.phonemes.char_to_id(k))

    def test_id_to_char(self):
        # Same reasoning as in test_char_to_id: the failure must be observable.
        with self.assertRaises(KeyError):
            self.empty_vocab.id_to_char(0)
        for k in self.phonemes.vocab:
            v = self.phonemes.char_to_id(k)
            self.assertEqual(self.base_vocab.id_to_char(v), self.phonemes.id_to_char(v))
class BaseCharacterTest(unittest.TestCase):
def setUp(self):