Refactor GlowTTS model and recipe for TTSTokenizer

Eren Gölge 2021-11-16 13:36:35 +01:00
parent 5a9653978a
commit bd461ace33
4 changed files with 54 additions and 41 deletions

View File

@@ -22,10 +22,13 @@ class BaseModel(nn.Module, ABC):
     def __init__(self, config: Coqpit):
         super().__init__()
-        self._set_model_args(config)
 
-    def _set_model_args(self, config: Coqpit):
-        """Set model arguments from the config. Override this."""
+    @staticmethod
+    def init_from_config(config: Coqpit):
+        """Init the model from the given config.
+
+        Override this depending on your model.
+        """
         pass
 
     @abstractmethod
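For orientation, here is a minimal sketch of how a subclass might implement the new `init_from_config` hook. The `MyTTS` class and its body are hypothetical, for illustration only:

```python
from coqpit import Coqpit


class MyTTS(BaseModel):  # hypothetical subclass; BaseModel is the abstract base patched above
    @staticmethod
    def init_from_config(config: Coqpit):
        # Build the model (plus any helper objects it needs) from the config
        # and return the ready-to-use instance.
        return MyTTS(config)
```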

View File

@@ -15,7 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text import make_symbols
+from TTS.tts.utils.text.symbols import Graphemes, make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 
 # pylint: skip-file
@@ -34,8 +34,20 @@ class BaseTTS(BaseModel):
         - 1D tensors `batch x 1`
     """
 
+    def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
+        super().__init__(config)
+        self.config = config
+        self.ap = ap
+        self.tokenizer = tokenizer
+        self.speaker_manager = speaker_manager
+        self._set_model_args(config)
+
     def _set_model_args(self, config: Coqpit):
-        """Setup model args based on the config type.
+        """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
+
+        `ModelArgs` has all the fields required to initialize the model architecture.
+        `ModelConfig` has all the fields required for training and inference, and contains `ModelArgs`.
 
         If the config is for training with a name like "*Config", then the model args are embedded in the
         config.model_args
@@ -44,8 +56,8 @@ class BaseTTS(BaseModel):
         """
         # don't use isinstance to avoid a recursive import
         if "Config" in config.__class__.__name__:
+            num_chars = self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
             if "characters" in config:
-                _, self.config, num_chars = self.get_characters(config)
                 self.config.num_chars = num_chars
             if hasattr(self.config, "model_args"):
                 config.model_args.num_chars = num_chars
@@ -58,18 +70,21 @@ class BaseTTS(BaseModel):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
-    @staticmethod
-    def get_characters(config: Coqpit) -> str:
-        # TODO: implement CharacterProcessor
-        if config.characters is not None:
-            symbols, phonemes = make_symbols(**config.characters)
-        else:
-            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
-            config.characters = CharactersConfig(**parse_symbols())
-        model_characters = phonemes if config.use_phonemes else symbols
-        num_chars = len(model_characters) + getattr(config, "add_blank", False)
-        return model_characters, config, num_chars
+    # @staticmethod
+    # def get_characters(config: Coqpit) -> str:
+    #     # TODO: implement CharacterProcessor
+    #     if config.characters is not None:
+    #         symbols, phonemes = make_symbols(**config.characters)
+    #     else:
+    #         from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
+    #         if config.use_phonemes:
+    #             config.characters = Graphemes()
+    #     model_characters = phonemes if config.use_phonemes else symbols
+    #     num_chars = len(model_characters) + getattr(config, "add_blank", False)
+    #     return model_characters, config, num_chars
 
     def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
         return get_speaker_manager(config, restore_path, data, out_path)
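The net effect of the two changes above: `get_characters` is retired, and the vocabulary size is resolved from the tokenizer when one is attached. Spelled out, the one-liner added to `_set_model_args` is equivalent to:

```python
# The tokenizer, when present, is the single source of truth for the
# character set; otherwise fall back to the value stored in the config.
if self.tokenizer is None:
    num_chars = self.config.model_args.num_chars
else:
    num_chars = self.tokenizer.characters.num_chars
```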
@@ -247,8 +262,6 @@ class BaseTTS(BaseModel):
         if is_eval and not config.run_eval:
             loader = None
         else:
-            ap = assets["audio_processor"]
-
             # setup multi-speaker attributes
             if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                 if hasattr(config, "model_args"):
@@ -279,28 +292,21 @@ class BaseTTS(BaseModel):
             # init dataloader
             dataset = TTSDataset(
                 outputs_per_step=config.r if "r" in config else 1,
-                text_cleaner=config.text_cleaner,
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 compute_f0=config.get("compute_f0", False),
                 f0_cache_path=config.get("f0_cache_path", None),
                 meta_data=data_items,
-                ap=ap,
-                characters=config.characters,
-                custom_symbols=custom_symbols,
-                add_blank=config["add_blank"],
+                ap=self.ap,
                 return_wav=config.return_wav if "return_wav" in config else False,
                 batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                 min_seq_len=config.min_seq_len,
                 max_seq_len=config.max_seq_len,
                 phoneme_cache_path=config.phoneme_cache_path,
-                use_phonemes=config.use_phonemes,
-                phoneme_language=config.phoneme_language,
-                enable_eos_bos=config.enable_eos_bos_chars,
                 use_noise_augment=False if is_eval else config.use_noise_augment,
                 verbose=verbose,
                 speaker_id_mapping=speaker_id_mapping,
-                d_vector_mapping=d_vector_mapping,
-                language_id_mapping=language_id_mapping,
+                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
+                tokenizer=self.tokenizer
             )
 
             # pre-compute phonemes
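All seven removed keyword arguments (`text_cleaner`, `characters`, `custom_symbols`, `add_blank`, `use_phonemes`, `phoneme_language`, `enable_eos_bos`) concern text processing, which now lives behind the tokenizer. A rough sketch of the delegation pattern, assuming a `text_to_ids` method on the tokenizer; this is illustrative, not the actual `TTSDataset` internals:

```python
class TokenizedDatasetSketch:
    """Illustrative only: defers all text handling to the tokenizer
    instead of cleaning/phonemizing inside the dataset."""

    def __init__(self, meta_data, ap, tokenizer):
        self.meta_data = meta_data  # list of [text, audio_file_path, speaker_name]
        self.ap = ap  # AudioProcessor, used for audio I/O only
        self.tokenizer = tokenizer  # owns cleaning, phonemization, blank insertion

    def __getitem__(self, idx):
        text, wav_path, _ = self.meta_data[idx]
        token_ids = self.tokenizer.text_to_ids(text)  # assumed tokenizer API
        wav = self.ap.load_wav(wav_path)
        return {"token_ids": token_ids, "wav": wav}
```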
@@ -332,7 +338,7 @@ class BaseTTS(BaseModel):
         if config.compute_f0 and rank in [None, 0]:
             if not os.path.exists(config.f0_cache_path):
                 dataset.pitch_extractor.compute_pitch(
-                    ap, config.get("f0_cache_path", None), config.num_loader_workers
+                    self.ap, config.get("f0_cache_path", None), config.num_loader_workers
                 )
 
             # halt DDP processes for the main process to finish computing the F0 cache
@@ -404,6 +410,7 @@ class BaseTTS(BaseModel):
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
         ap = assets["audio_processor"]
+        tokenizer = assets["tokenizer"]
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -416,6 +423,7 @@ class BaseTTS(BaseModel):
                 self.config,
                 "cuda" in str(next(self.parameters()).device),
                 ap,
+                tokenizer,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
                 style_wav=aux_inputs["style_wav"],
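Both `test_run` implementations now thread the tokenizer into `synthesis()` right after the audio processor. A hedged sketch of the call shape; the keyword names are inferred from the positional usage in the diff, and `model`, `sentence`, `ap`, and `tokenizer` stand in for the objects available inside `test_run`:

```python
from TTS.tts.utils.synthesis import synthesis

outputs = synthesis(
    model,
    sentence,  # raw test sentence (str)
    model.config,
    use_cuda="cuda" in str(next(model.parameters()).device),
    ap=ap,  # audio processor, as before
    tokenizer=tokenizer,  # new: converts the sentence to token IDs
    speaker_id=None,
    d_vector=None,
    style_wav=None,
)
```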

View File

@@ -46,11 +46,9 @@ class GlowTTS(BaseTTS):
     """
 
-    def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
-
-        super().__init__(config)
-
-        self.speaker_manager = speaker_manager
+    def __init__(self, config: GlowTTSConfig, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
+
+        super().__init__(config, ap, tokenizer, speaker_manager)
 
         # pass all config fields to `self`
         # for fewer code changes
@@ -58,7 +56,7 @@ class GlowTTS(BaseTTS):
         for key in config:
             setattr(self, key, config[key])
 
-        _, self.config, self.num_chars = self.get_characters(config)
+        self.num_chars = self.tokenizer.characters.num_chars
         self.decoder_output_dim = config.out_channels
 
         # init multi-speaker layers if necessary
@@ -448,7 +446,6 @@ class GlowTTS(BaseTTS):
         Returns:
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
-        ap = assets["audio_processor"]
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -463,7 +460,8 @@ class GlowTTS(BaseTTS):
                 sen,
                 self.config,
                 "cuda" in str(next(self.parameters()).device),
-                ap,
+                self.ap,
+                self.tokenizer,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
                 style_wav=aux_inputs["style_wav"],

View File

@@ -11,6 +11,7 @@ from TTS.tts.configs.glow_tts_config import GlowTTSConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.glow_tts import GlowTTS
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
 
 # we use the directory of this script as our training folder.
@@ -47,7 +48,11 @@ config = GlowTTSConfig(
 # INITIALIZE THE AUDIO PROCESSOR
 # Audio processor is used for feature extraction and audio I/O.
 # It mainly serves the dataloader and the training loggers.
-ap = AudioProcessor(**config.audio.to_dict())
+ap = AudioProcessor.init_from_config(config)
+
+# INITIALIZE THE TOKENIZER
+# Tokenizer is used to convert text to sequences of token IDs.
+tokenizer = TTSTokenizer.init_from_config(config)
 
 # LOAD DATA SAMPLES
 # Each sample is a list of ```[text, audio_file_path, speaker_name]```
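As a quick sanity check of what the tokenizer initialized above provides, assuming the `text_to_ids` conversion implied by its comment and the `characters.num_chars` attribute used in `base_tts.py`:

```python
# Convert a test sentence to token IDs and inspect the vocabulary size (illustrative).
token_ids = tokenizer.text_to_ids("Hello world!")  # assumed API
print(f"{len(token_ids)} tokens, vocab size {tokenizer.characters.num_chars}")
```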
@@ -60,7 +65,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
 # Models take a config object and a speaker manager as input
 # Config defines the details of the model like the number of layers, the size of the embedding, etc.
 # Speaker manager is used by multi-speaker models.
-model = GlowTTS(config, speaker_manager=None)
+model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
 
 # INITIALIZE THE TRAINER
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
@@ -71,8 +76,7 @@ trainer = Trainer(
     output_path,
     model=model,
     train_samples=train_samples,
-    eval_samples=eval_samples,
-    training_assets={"audio_processor": ap},  # assets are objects used by the models but not class members.
+    eval_samples=eval_samples
 )
 
 # AND... 3,2,1... 🚀
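The countdown comment suggests the recipe closes by launching training; presumably the final line is:

```python
trainer.fit()
```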