mirror of https://github.com/coqui-ai/TTS.git
Implement BaseVocabulary
This commit is contained in:
parent
833de62e30
commit
2bad098625
|
@ -1,4 +1,6 @@
|
||||||
|
from abc import ABC
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||||
|
|
||||||
|
@ -79,6 +81,71 @@ _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprase
|
||||||
# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
|
# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseVocabulary:
|
||||||
|
"""Base Vocabulary class.
|
||||||
|
|
||||||
|
This class only needs a vocabulary dictionary without specifying the characters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab (Dict): A dictionary of characters and their corresponding indices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.pad = pad
|
||||||
|
self.blank = blank
|
||||||
|
self.bos = bos
|
||||||
|
self.eos = eos
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad_id(self) -> int:
|
||||||
|
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def blank_id(self) -> int:
|
||||||
|
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab(self):
|
||||||
|
return self._vocab
|
||||||
|
|
||||||
|
@vocab.setter
|
||||||
|
def vocab(self, vocab):
|
||||||
|
self._vocab = vocab
|
||||||
|
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
|
||||||
|
self._id_to_char = {
|
||||||
|
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_from_config(config, **kwargs):
|
||||||
|
if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
|
||||||
|
return (
|
||||||
|
BaseVocabulary(
|
||||||
|
config.characters.vocab_dict,
|
||||||
|
config.characters.pad,
|
||||||
|
config.characters.blank,
|
||||||
|
config.characters.bos,
|
||||||
|
config.characters.eos,
|
||||||
|
),
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
return BaseVocabulary(**kwargs), config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_chars(self):
|
||||||
|
return max(self._vocab.values()) + 1
|
||||||
|
|
||||||
|
def char_to_id(self, char: str) -> int:
|
||||||
|
try:
|
||||||
|
return self._char_to_id[char]
|
||||||
|
except KeyError as e:
|
||||||
|
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
|
||||||
|
|
||||||
|
def id_to_char(self, idx: int) -> str:
|
||||||
|
return self._id_to_char[idx]
|
||||||
|
|
||||||
|
|
||||||
class BaseCharacters:
|
class BaseCharacters:
|
||||||
"""🐸BaseCharacters class
|
"""🐸BaseCharacters class
|
||||||
|
|
||||||
|
@ -116,12 +183,12 @@ class BaseCharacters:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
characters: str,
|
characters: str = None,
|
||||||
punctuations: str,
|
punctuations: str = None,
|
||||||
pad: str,
|
pad: str = None,
|
||||||
eos: str,
|
eos: str = None,
|
||||||
bos: str,
|
bos: str = None,
|
||||||
blank: str,
|
blank: str = None,
|
||||||
is_unique: bool = False,
|
is_unique: bool = False,
|
||||||
is_sorted: bool = True,
|
is_sorted: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -135,6 +202,14 @@ class BaseCharacters:
|
||||||
self.is_sorted = is_sorted
|
self.is_sorted = is_sorted
|
||||||
self._create_vocab()
|
self._create_vocab()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad_id(self) -> int:
|
||||||
|
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def blank_id(self) -> int:
|
||||||
|
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def characters(self):
|
def characters(self):
|
||||||
return self._characters
|
return self._characters
|
||||||
|
@ -193,6 +268,14 @@ class BaseCharacters:
|
||||||
def vocab(self):
|
def vocab(self):
|
||||||
return self._vocab
|
return self._vocab
|
||||||
|
|
||||||
|
@vocab.setter
|
||||||
|
def vocab(self, vocab):
|
||||||
|
self._vocab = vocab
|
||||||
|
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
||||||
|
self._id_to_char = {
|
||||||
|
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_chars(self):
|
def num_chars(self):
|
||||||
return len(self._vocab)
|
return len(self._vocab)
|
||||||
|
@ -208,11 +291,7 @@ class BaseCharacters:
|
||||||
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
|
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
|
||||||
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
|
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
|
||||||
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
|
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
|
||||||
self._vocab = _vocab + list(self._punctuations)
|
self.vocab = _vocab + list(self._punctuations)
|
||||||
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
|
||||||
self._id_to_char = {
|
|
||||||
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
|
|
||||||
}
|
|
||||||
if self.is_unique:
|
if self.is_unique:
|
||||||
duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
|
duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
|
||||||
assert (
|
assert (
|
||||||
|
@ -248,7 +327,13 @@ class BaseCharacters:
|
||||||
|
|
||||||
Implement this method for your subclass.
|
Implement this method for your subclass.
|
||||||
"""
|
"""
|
||||||
...
|
# use character set from config
|
||||||
|
if config.characters is not None:
|
||||||
|
return BaseCharacters(**config.characters), config
|
||||||
|
# return default character set
|
||||||
|
characters = BaseCharacters()
|
||||||
|
new_config = replace(config, characters=characters.to_config())
|
||||||
|
return characters, new_config
|
||||||
|
|
||||||
def to_config(self) -> "CharactersConfig":
|
def to_config(self) -> "CharactersConfig":
|
||||||
return CharactersConfig(
|
return CharactersConfig(
|
||||||
|
|
|
@ -30,20 +30,6 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
|
||||||
return grad_norm, skip_flag
|
return grad_norm, skip_flag
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
|
|
||||||
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
|
|
||||||
self.warmup_steps = float(warmup_steps)
|
|
||||||
super().__init__(optimizer, last_epoch)
|
|
||||||
|
|
||||||
def get_lr(self):
|
|
||||||
step = max(self.last_epoch, 1)
|
|
||||||
return [
|
|
||||||
base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
|
|
||||||
for base_lr in self.base_lrs
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def gradual_training_scheduler(global_step, config):
|
def gradual_training_scheduler(global_step, config):
|
||||||
"""Setup the gradual training schedule wrt number
|
"""Setup the gradual training schedule wrt number
|
||||||
of active GPUs"""
|
of active GPUs"""
|
||||||
|
|
Loading…
Reference in New Issue