Implement BaseVocabulary

This commit is contained in:
Eren Gölge 2022-02-20 11:47:42 +01:00
parent 833de62e30
commit 2bad098625
2 changed files with 97 additions and 26 deletions

View File

@ -1,4 +1,6 @@
from abc import ABC
from dataclasses import replace
from typing import Dict
from TTS.tts.configs.shared_configs import CharactersConfig
@ -79,6 +81,71 @@ _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprase
# DEF_PHONEMES = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
class BaseVocabulary:
"""Base Vocabulary class.
This class only needs a vocabulary dictionary without specifying the characters.
Args:
vocab (Dict): A dictionary of characters and their corresponding indices.
"""
def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
self.vocab = vocab
self.pad = pad
self.blank = blank
self.bos = bos
self.eos = eos
@property
def pad_id(self) -> int:
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
@property
def blank_id(self) -> int:
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
@property
def vocab(self):
return self._vocab
@vocab.setter
def vocab(self, vocab):
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
}
@staticmethod
def init_from_config(config, **kwargs):
if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
return (
BaseVocabulary(
config.characters.vocab_dict,
config.characters.pad,
config.characters.blank,
config.characters.bos,
config.characters.eos,
),
config,
)
return BaseVocabulary(**kwargs), config
@property
def num_chars(self):
return max(self._vocab.values()) + 1
def char_to_id(self, char: str) -> int:
try:
return self._char_to_id[char]
except KeyError as e:
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
def id_to_char(self, idx: int) -> str:
return self._id_to_char[idx]
class BaseCharacters:
"""🐸BaseCharacters class
@ -116,12 +183,12 @@ class BaseCharacters:
def __init__(
self,
characters: str,
punctuations: str,
pad: str,
eos: str,
bos: str,
blank: str,
characters: str = None,
punctuations: str = None,
pad: str = None,
eos: str = None,
bos: str = None,
blank: str = None,
is_unique: bool = False,
is_sorted: bool = True,
) -> None:
@ -135,6 +202,14 @@ class BaseCharacters:
self.is_sorted = is_sorted
self._create_vocab()
@property
def pad_id(self) -> int:
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
@property
def blank_id(self) -> int:
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
@property
def characters(self):
return self._characters
@ -193,6 +268,14 @@ class BaseCharacters:
def vocab(self):
return self._vocab
@vocab.setter
def vocab(self, vocab):
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
}
@property
def num_chars(self):
return len(self._vocab)
@ -208,11 +291,7 @@ class BaseCharacters:
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
self._vocab = _vocab + list(self._punctuations)
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
}
self.vocab = _vocab + list(self._punctuations)
if self.is_unique:
duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
assert (
@ -248,7 +327,13 @@ class BaseCharacters:
Implement this method for your subclass.
"""
...
# use character set from config
if config.characters is not None:
return BaseCharacters(**config.characters), config
# return default character set
characters = BaseCharacters()
new_config = replace(config, characters=characters.to_config())
return characters, new_config
def to_config(self) -> "CharactersConfig":
return CharactersConfig(

View File

@ -30,20 +30,6 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
return grad_norm, skip_flag
# pylint: disable=protected-access
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
self.warmup_steps = float(warmup_steps)
super().__init__(optimizer, last_epoch)
def get_lr(self):
step = max(self.last_epoch, 1)
return [
base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
for base_lr in self.base_lrs
]
def gradual_training_scheduler(global_step, config):
"""Setup the gradual training schedule wrt number
of active GPUs"""