Update imports for symbols -> characters

This commit is contained in:
Eren Gölge 2021-11-17 12:46:04 +01:00
parent a1df4f9887
commit fbad17e084
10 changed files with 26 additions and 337 deletions

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_checkpoint

View File

@ -4,7 +4,7 @@ from coqpit import Coqpit
from TTS.config import load_config, register_config
from TTS.trainer import TrainingArgs
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
from TTS.utils.logging import init_dashboard_logger

View File

@ -1,4 +1,4 @@
from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.tts.utils.text.characters import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module
@ -17,7 +17,7 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manag
else:
symbols, phonemes = make_symbols(**config.characters)
else:
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
from TTS.tts.utils.text.characters import phonemes, symbols # pylint: disable=import-outside-toplevel
if config.use_phonemes:
symbols = phonemes

View File

@ -15,7 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import Graphemes, make_symbols
from TTS.tts.utils.text.characters import Graphemes, make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
@ -34,7 +34,9 @@ class BaseTTS(BaseModel):
- 1D tensors `batch x 1`
"""
def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
def __init__(
self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None
):
super().__init__(config)
self.config = config
self.ap = ap
@ -56,7 +58,9 @@ class BaseTTS(BaseModel):
"""
# don't use isintance not to import recursively
if "Config" in config.__class__.__name__:
num_chars = self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
num_chars = (
self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
)
if "characters" in config:
self.config.num_chars = num_chars
if hasattr(self.config, "model_args"):
@ -76,7 +80,7 @@ class BaseTTS(BaseModel):
# if config.characters is not None:
# symbols, phonemes = make_symbols(**config.characters)
# else:
# from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
# from TTS.tts.utils.text.characters import parse_symbols, phonemes, symbols
# if config.use_phonemes:
@ -306,7 +310,7 @@ class BaseTTS(BaseModel):
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer
tokenizer=self.tokenizer,
)
# pre-compute phonemes

View File

@ -46,7 +46,13 @@ class GlowTTS(BaseTTS):
"""
def __init__(self, config: GlowTTSConfig, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
def __init__(
self,
config: GlowTTSConfig,
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)

View File

@ -5,7 +5,6 @@ import torch
from torch import nn
def numpy_to_torch(np_array, dtype, cuda=False):
if np_array is None:
return None

View File

@ -32,7 +32,6 @@ class BasePhonemizer(abc.ABC):
self._keep_puncs = keep_puncs
self._punctuator = Punctuation(punctuations)
def _init_language(self, language):
"""Language initialization
@ -130,7 +129,7 @@ class BasePhonemizer(abc.ABC):
phonemized = self._phonemize_postprocess(phonemized, punctuations)
return phonemized
def print_logs(self, level: int=0):
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.language}")
print(f"{indent}| > phoneme backend: {self.name()}")

View File

@ -1,316 +0,0 @@
# -*- coding: utf-8 -*-
"""
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
"""
def parse_symbols():
return {
"pad": _pad,
"eos": _eos,
"bos": _bos,
"characters": _characters,
"punctuations": _punctuations,
"phonemes": _phonemes,
}
def make_symbols(
characters,
phonemes=None,
punctuations="!'(),-.:;? ",
pad="<PAD>",
eos="<EOS>",
bos="<BOS>",
blank="<BLNK>",
unique=True,
): # pylint: disable=redefined-outer-name
"""Function to create default characters and phonemes"""
_symbols = list(characters)
_symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols
_symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols
_symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols
_symbols = [blank] + _symbols if len(bos) > 0 and blank is not None else _symbols
_phonemes = None
if phonemes is not None:
_phonemes_sorted = (
sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
) # this is to keep previous models compatible.
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ["@" + s for s in _phonemes_sorted]
# Export all symbols:
_phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
# _symbols += _arpabet
return _symbols, _phonemes
_pad = "<PAD>"
_eos = "<EOS>"
_bos = "<BOS>"
_blank = "<BLNK>" # TODO: check if we need this alongside with PAD
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? "
_punctuations = "!'(),-.:;? "
# Phonemes definition (All IPA characters)
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "ˈˌːˑ"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)
class BaseCharacters:
"""🐸BaseCharacters class
Every vocabulary class should inherit from this class.
Args:
characters (str):
Main set of characters to be used in the vocabulary.
punctuations (str):
Characters to be treated as punctuation.
pad (str):
Special padding character that would be ignored by the model.
eos (str):
End of the sentence character.
bos (str):
Beginning of the sentence character.
blank (str):
Optional character used between characters by some models for better prosody.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
"""
def __init__(
self,
characters: str,
punctuations: str,
pad: str,
eos: str,
bos: str,
blank: str,
is_unique: bool = True,
is_sorted: bool = True,
) -> None:
self._characters = characters
self._punctuations = punctuations
self._pad = pad
self._eos = eos
self._bos = bos
self._blank = blank
self.is_unique = is_unique
self.is_sorted = is_sorted
self._create_vocab()
@property
def characters(self):
return self._characters
@characters.setter
def characters(self, characters):
self._characters = characters
self._vocab = self.create_vocab()
@property
def punctuations(self):
return self._punctuations
@punctuations.setter
def punctuations(self, punctuations):
self._punctuations = punctuations
self._vocab = self.create_vocab()
@property
def pad(self):
return self._pad
@pad.setter
def pad(self, pad):
self._pad = pad
self._vocab = self.create_vocab()
@property
def eos(self):
return self._eos
@eos.setter
def eos(self, eos):
self._eos = eos
self._vocab = self.create_vocab()
@property
def bos(self):
return self._bos
@bos.setter
def bos(self, bos):
self._bos = bos
self._vocab = self.create_vocab()
@property
def blank(self):
return self._bos
@bos.setter
def blank(self, bos):
self._bos = bos
self._vocab = self.create_vocab()
@property
def vocab(self):
return self._vocab
@property
def num_chars(self):
return len(self._vocab)
def _create_vocab(self):
_vocab = self.characters
if self.is_unique:
_vocab = list(set(_vocab))
if self.is_sorted:
_vocab = sorted(_vocab)
_vocab = list(_vocab)
_vocab = [self.bos] + _vocab if len(self.bos) > 0 and self.bos is not None else _vocab
_vocab = [self.eos] + _vocab if len(self.bos) > 0 and self.eos is not None else _vocab
_vocab = [self.pad] + _vocab if len(self.bos) > 0 and self.pad is not None else _vocab
self._vocab = _vocab + list(self._punctuations)
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
assert len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
def char_to_id(self, char: str) -> int:
return self._char_to_id[char]
def id_to_char(self, idx: int) -> str:
return self._id_to_char[idx]
@staticmethod
def init_from_config(config: "Coqpit"):
return BaseCharacters(
**config.characters if config.characters is not None else {},
)
class IPAPhonemes(BaseCharacters):
"""🐸IPAPhonemes class to manage `TTS.tts` model vocabulary
Intended to be used with models using IPAPhonemes as input.
It uses system defaults for the undefined class arguments.
Args:
characters (str):
Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`.
punctuations (str):
Characters to be treated as punctuation. Defaults to `_punctuations`.
pad (str):
Special padding character that would be ignored by the model. Defaults to `_pad`.
eos (str):
End of the sentence character. Defaults to `_eos`.
bos (str):
Beginning of the sentence character. Defaults to `_bos`.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
"""
def __init__(
self,
characters: str = _phonemes,
punctuations: str = _punctuations,
pad: str = _pad,
eos: str = _eos,
bos: str = _bos,
is_unique: bool = True,
is_sorted: bool = True,
) -> None:
super().__init__(characters, punctuations, pad, eos, bos, is_unique, is_sorted)
@staticmethod
def init_from_config(config: "Coqpit"):
return IPAPhonemes(
**config.characters if config.characters is not None else {},
)
class Graphemes(BaseCharacters):
"""🐸Graphemes class to manage `TTS.tts` model vocabulary
Intended to be used with models using graphemes as input.
It uses system defaults for the undefined class arguments.
Args:
characters (str):
Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`.
punctuations (str):
Characters to be treated as punctuation. Defaults to `_punctuations`.
pad (str):
Special padding character that would be ignored by the model. Defaults to `_pad`.
eos (str):
End of the sentence character. Defaults to `_eos`.
bos (str):
Beginning of the sentence character. Defaults to `_bos`.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
"""
def __init__(
self,
characters: str = _characters,
punctuations: str = _punctuations,
pad: str = _pad,
eos: str = _eos,
bos: str = _bos,
is_unique: bool = True,
is_sorted: bool = True,
) -> None:
super().__init__(characters, punctuations, pad, eos, bos, is_unique, is_sorted)
@staticmethod
def init_from_config(config: "Coqpit"):
return Graphemes(
**config.characters if config.characters is not None else {},
)
if __name__ == "__main__":
gr = Graphemes()
ph = IPAPhonemes()
print(gr.vocab)
print(ph.vocab)
print(gr.num_chars)
assert "a" == gr.id_to_char(gr.char_to_id("a"))

View File

@ -2,7 +2,7 @@ from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.tts.utils.text.symbols import Graphemes, IPAPhonemes
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
class TTSTokenizer:
@ -117,4 +117,6 @@ class TTSTokenizer:
phonemizer = get_phonemizer_by_name(DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs)
else:
characters = Graphemes().init_from_config(config)
return TTSTokenizer(config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars)
return TTSTokenizer(
config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars
)

View File

@ -71,12 +71,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀