mirror of https://github.com/coqui-ai/TTS.git
Implement BaseVocabulary
This commit is contained in:
parent
833de62e30
commit
2bad098625
|
@ -1,4 +1,6 @@
|
|||
from abc import ABC
|
||||
from dataclasses import replace
|
||||
from typing import Dict
|
||||
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
|
||||
|
@ -79,6 +81,71 @@ _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprase
|
|||
# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
|
||||
|
||||
|
||||
class BaseVocabulary:
    """Base Vocabulary class.

    This class only needs a vocabulary dictionary without specifying the characters.

    Args:
        vocab (Dict): A dictionary of characters and their corresponding indices. A plain
            sequence of characters is also accepted, in which case every character is
            assigned its position in the sequence as index.
        pad (str): Padding character. Defaults to None.
        blank (str): Blank character. Defaults to None.
        bos (str): Beginning-of-sequence character. Defaults to None.
        eos (str): End-of-sequence character. Defaults to None.
    """

    def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
        # Goes through the ``vocab`` property setter, which also builds the lookup tables.
        self.vocab = vocab
        self.pad = pad
        self.blank = blank
        self.bos = bos
        self.eos = eos

    @property
    def pad_id(self) -> int:
        """Index of the padding character, or ``len(self.vocab)`` when no padding character is set."""
        return self.char_to_id(self.pad) if self.pad else len(self.vocab)

    @property
    def blank_id(self) -> int:
        """Index of the blank character, or ``len(self.vocab)`` when no blank character is set."""
        return self.char_to_id(self.blank) if self.blank else len(self.vocab)

    @property
    def vocab(self):
        """The vocabulary object exactly as it was assigned."""
        return self._vocab

    @vocab.setter
    def vocab(self, vocab):
        """Store the vocabulary and rebuild the char <-> id lookup tables."""
        self._vocab = vocab
        if hasattr(vocab, "items"):
            # Mapping input: honour the indices declared in the dictionary instead of
            # silently replacing them with the insertion position of each key.
            pairs = list(vocab.items())
        else:
            # Sequence input: the position of a character is its index.
            pairs = [(char, idx) for idx, char in enumerate(vocab)]
        self._char_to_id = dict(pairs)
        self._id_to_char = {idx: char for char, idx in pairs}

    @staticmethod
    def init_from_config(config, **kwargs):
        """Build a vocabulary from ``config.characters.vocab_dict`` when it is set and non-empty.

        Otherwise the vocabulary is built from ``kwargs``.

        Returns:
            Tuple of the new ``BaseVocabulary`` and the (unmodified) ``config``.
        """
        if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
            return (
                BaseVocabulary(
                    config.characters.vocab_dict,
                    config.characters.pad,
                    config.characters.blank,
                    config.characters.bos,
                    config.characters.eos,
                ),
                config,
            )
        return BaseVocabulary(**kwargs), config

    @property
    def num_chars(self):
        """Largest index in the vocabulary plus one; ``0`` for an empty vocabulary."""
        # Read from the lookup table so this also works when the vocabulary was given as a sequence.
        return max(self._char_to_id.values(), default=-1) + 1

    def char_to_id(self, char: str) -> int:
        """Return the index of ``char``.

        Raises:
            KeyError: If ``char`` is not in the vocabulary.
        """
        try:
            return self._char_to_id[char]
        except KeyError as e:
            raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e

    def id_to_char(self, idx: int) -> str:
        """Return the character stored under index ``idx``.

        Raises:
            KeyError: If no character uses ``idx``.
        """
        try:
            return self._id_to_char[idx]
        except KeyError as e:
            raise KeyError(f" [!] {repr(idx)} is not a valid index in the vocabulary.") from e
|
||||
|
||||
|
||||
class BaseCharacters:
|
||||
"""🐸BaseCharacters class
|
||||
|
||||
|
@ -116,12 +183,12 @@ class BaseCharacters:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
characters: str,
|
||||
punctuations: str,
|
||||
pad: str,
|
||||
eos: str,
|
||||
bos: str,
|
||||
blank: str,
|
||||
characters: str = None,
|
||||
punctuations: str = None,
|
||||
pad: str = None,
|
||||
eos: str = None,
|
||||
bos: str = None,
|
||||
blank: str = None,
|
||||
is_unique: bool = False,
|
||||
is_sorted: bool = True,
|
||||
) -> None:
|
||||
|
@ -135,6 +202,14 @@ class BaseCharacters:
|
|||
self.is_sorted = is_sorted
|
||||
self._create_vocab()
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
|
||||
|
||||
@property
|
||||
def blank_id(self) -> int:
|
||||
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
||||
|
||||
@property
|
||||
def characters(self):
|
||||
return self._characters
|
||||
|
@ -193,6 +268,14 @@ class BaseCharacters:
|
|||
def vocab(self):
|
||||
return self._vocab
|
||||
|
||||
@vocab.setter
|
||||
def vocab(self, vocab):
|
||||
self._vocab = vocab
|
||||
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
||||
self._id_to_char = {
|
||||
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
|
||||
}
|
||||
|
||||
@property
|
||||
def num_chars(self):
|
||||
return len(self._vocab)
|
||||
|
@ -208,11 +291,7 @@ class BaseCharacters:
|
|||
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
|
||||
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
|
||||
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
|
||||
self._vocab = _vocab + list(self._punctuations)
|
||||
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
||||
self._id_to_char = {
|
||||
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
|
||||
}
|
||||
self.vocab = _vocab + list(self._punctuations)
|
||||
if self.is_unique:
|
||||
duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
|
||||
assert (
|
||||
|
@ -248,7 +327,13 @@ class BaseCharacters:
|
|||
|
||||
Implement this method for your subclass.
|
||||
"""
|
||||
...
|
||||
# use character set from config
|
||||
if config.characters is not None:
|
||||
return BaseCharacters(**config.characters), config
|
||||
# return default character set
|
||||
characters = BaseCharacters()
|
||||
new_config = replace(config, characters=characters.to_config())
|
||||
return characters, new_config
|
||||
|
||||
def to_config(self) -> "CharactersConfig":
|
||||
return CharactersConfig(
|
||||
|
|
|
@ -30,20 +30,6 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
|
|||
return grad_norm, skip_flag
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Learning-rate schedule with a warm-up ramp followed by inverse-square-root decay.

    For step ``s`` and warm-up length ``w`` every base learning rate is scaled by
    ``w**0.5 * min(s * w**-1.5, s**-0.5)``, which equals 1 at ``s == w``.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_steps: Length of the warm-up phase; stored as a float.
        last_epoch: Forwarded unchanged to the parent scheduler.
    """

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        # Kept ahead of the parent constructor call, as in the original ordering
        # (presumably because the parent may query get_lr() while initialising - TODO confirm).
        self.warmup_steps = float(warmup_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one scaled learning rate per entry of ``self.base_lrs``."""
        # Clamp to 1 so that step 0 does not produce a zero step count.
        current = max(self.last_epoch, 1)
        warmup = self.warmup_steps
        root = warmup**0.5
        factor = min(current * warmup**-1.5, current**-0.5)
        rates = []
        for base_lr in self.base_lrs:
            # Same left-to-right multiplication order as the single-expression form.
            rates.append(base_lr * root * factor)
        return rates
|
||||
|
||||
|
||||
def gradual_training_scheduler(global_step, config):
|
||||
"""Setup the gradual training schedule wrt number
|
||||
of active GPUs"""
|
||||
|
|
Loading…
Reference in New Issue