mirror of https://github.com/coqui-ai/TTS.git
Update Characters and add more tests
This commit is contained in:
parent
7de5afc29a
commit
690de1ab06
|
@ -35,51 +35,6 @@ _diacrilics = "ɚ˞ɫ"
|
|||
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
|
||||
|
||||
|
||||
# def create_graphemes(
|
||||
# characters=_characters,
|
||||
# punctuations=_punctuations,
|
||||
# pad=_pad,
|
||||
# eos=_eos,
|
||||
# bos=_bos,
|
||||
# blank=_blank,
|
||||
# unique=True,
|
||||
# ): # pylint: disable=redefined-outer-name
|
||||
# """Function to create default characters and phonemes"""
|
||||
# # create graphemes
|
||||
# = (
|
||||
# sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
|
||||
# ) # this is to keep previous models compatible.
|
||||
# _graphemes = list(characters)
|
||||
# _graphemes = [bos] + _graphemes if len(bos) > 0 and bos is not None else _graphemes
|
||||
# _graphemes = [eos] + _graphemes if len(bos) > 0 and eos is not None else _graphemes
|
||||
# _graphemes = [pad] + _graphemes if len(bos) > 0 and pad is not None else _graphemes
|
||||
# _graphemes = [blank] + _graphemes if len(bos) > 0 and blank is not None else _graphemes
|
||||
# _graphemes = _graphemes + list(punctuations)
|
||||
# return _graphemes, _phonemes
|
||||
|
||||
|
||||
# def create_phonemes(
|
||||
# phonemes=_phonemes, punctuations=_punctuations, pad=_pad, eos=_eos, bos=_bos, blank=_blank, unique=True
|
||||
# ):
|
||||
# # create phonemes
|
||||
# _phonemes = None
|
||||
# _phonemes_sorted = (
|
||||
# sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
|
||||
# ) # this is to keep previous models compatible.
|
||||
# _phonemes = list(_phonemes_sorted)
|
||||
# _phonemes = [bos] + _phonemes if len(bos) > 0 and bos is not None else _phonemes
|
||||
# _phonemes = [eos] + _phonemes if len(bos) > 0 and eos is not None else _phonemes
|
||||
# _phonemes = [pad] + _phonemes if len(bos) > 0 and pad is not None else _phonemes
|
||||
# _phonemes = [blank] + _phonemes if len(bos) > 0 and blank is not None else _phonemes
|
||||
# _phonemes = _phonemes + list(punctuations)
|
||||
# _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
|
||||
# return _phonemes
|
||||
|
||||
|
||||
# DEF_GRAPHEMES = create_graphemes(_characters, _phonemes, _punctuations, _pad, _eos, _bos)
|
||||
# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
|
||||
|
||||
|
||||
class BaseVocabulary:
|
||||
"""Base Vocabulary class.
|
||||
|
||||
|
@ -98,18 +53,24 @@ class BaseVocabulary:
|
|||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
"""Return the index of the padding character. If the padding character is not specified, return the length
|
||||
of the vocabulary."""
|
||||
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
|
||||
|
||||
@property
|
||||
def blank_id(self) -> int:
|
||||
"""Return the index of the blank character. If the blank character is not specified, return the length of
|
||||
the vocabulary."""
|
||||
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
"""Return the vocabulary dictionary."""
|
||||
return self._vocab
|
||||
|
||||
@vocab.setter
|
||||
def vocab(self, vocab):
|
||||
"""Set the vocabulary dictionary and character mapping dictionaries."""
|
||||
self._vocab = vocab
|
||||
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
|
||||
self._id_to_char = {
|
||||
|
@ -118,6 +79,7 @@ class BaseVocabulary:
|
|||
|
||||
@staticmethod
|
||||
def init_from_config(config, **kwargs):
|
||||
"""Initialize from the given config."""
|
||||
if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
|
||||
return (
|
||||
BaseVocabulary(
|
||||
|
@ -133,15 +95,18 @@ class BaseVocabulary:
|
|||
|
||||
@property
|
||||
def num_chars(self):
|
||||
return max(self._vocab.values()) + 1
|
||||
"""Return number of tokens in the vocabulary."""
|
||||
return len(self._vocab)
|
||||
|
||||
def char_to_id(self, char: str) -> int:
|
||||
"""Map a character to an token ID."""
|
||||
try:
|
||||
return self._char_to_id[char]
|
||||
except KeyError as e:
|
||||
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
|
||||
|
||||
def id_to_char(self, idx: int) -> str:
|
||||
"""Map an token ID to a character."""
|
||||
return self._id_to_char[idx]
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,53 @@
|
|||
import unittest
|
||||
|
||||
from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes
|
||||
from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, BaseVocabulary
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
class BaseVocabularyTest(unittest.TestCase):
    """Compare a ``BaseVocabulary`` built from ``IPAPhonemes`` against its source, plus an empty vocabulary."""

    def setUp(self):
        # Reference character set; its vocabulary and special tokens seed the object under test.
        self.phonemes = IPAPhonemes()
        self.base_vocab = BaseVocabulary(
            vocab=self.phonemes._vocab,
            pad=self.phonemes.pad,
            blank=self.phonemes.blank,
            bos=self.phonemes.bos,
            eos=self.phonemes.eos,
        )
        # Vocabulary with no tokens, used for the boundary cases.
        self.empty_vocab = BaseVocabulary({})

    def test_pad_id(self):
        self.assertEqual(self.empty_vocab.pad_id, 0)
        self.assertEqual(self.base_vocab.pad_id, self.phonemes.pad_id)

    def test_blank_id(self):
        self.assertEqual(self.empty_vocab.blank_id, 0)
        self.assertEqual(self.base_vocab.blank_id, self.phonemes.blank_id)

    def test_vocab(self):
        self.assertEqual(self.empty_vocab.vocab, {})
        self.assertEqual(self.base_vocab.vocab, self.phonemes._vocab)

    @unittest.skip("TODO: test for BaseVocabulary.init_from_config is not implemented yet")
    def test_init_from_config(self):
        # Marked as skipped so the empty placeholder is not reported as a passing test.
        ...

    def test_num_chars(self):
        self.assertEqual(self.empty_vocab.num_chars, 0)
        self.assertEqual(self.base_vocab.num_chars, self.phonemes.num_chars)

    def test_char_to_id(self):
        # assertRaises fails the test when no KeyError is raised; the previous
        # try / bare-except pattern swallowed its own failure sentinel.
        with self.assertRaises(KeyError):
            self.empty_vocab.char_to_id("a")
        for k in self.phonemes.vocab:
            self.assertEqual(self.base_vocab.char_to_id(k), self.phonemes.char_to_id(k))

    def test_id_to_char(self):
        # Same reasoning as test_char_to_id: a missing ID must surface as KeyError.
        with self.assertRaises(KeyError):
            self.empty_vocab.id_to_char(0)
        for k in self.phonemes.vocab:
            v = self.phonemes.char_to_id(k)
            self.assertEqual(self.base_vocab.id_to_char(v), self.phonemes.id_to_char(v))
|
||||
|
||||
|
||||
class BaseCharacterTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
|
Loading…
Reference in New Issue