From 18f726af6594de93cd44c4c1bb7b9ccc037c3f58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:56:16 +0000 Subject: [PATCH] Update ForwardTTS --- TTS/tts/models/base_tts.py | 19 +++++++---------- TTS/tts/models/forward_tts.py | 40 ++++++++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 27231790..59862322 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -1,6 +1,6 @@ import os import random -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import torch import torch.distributed as dist @@ -56,9 +56,10 @@ class BaseTTS(BaseModel): """ # don't use isintance not to import recursively if "Config" in config.__class__.__name__: - num_chars = ( - self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + config_num_chars = ( + self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars ) + num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars if "characters" in config: self.config.num_chars = num_chars if hasattr(self.config, "model_args"): @@ -237,7 +238,7 @@ class BaseTTS(BaseModel): config: Coqpit, assets: Dict, is_eval: bool, - data_items: List, + samples: Union[List[Dict], List[List]], verbose: bool, num_gpus: int, rank: int = None, @@ -274,7 +275,7 @@ class BaseTTS(BaseModel): compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, compute_f0=config.get("compute_f0", False), f0_cache_path=config.get("f0_cache_path", None), - meta_data=data_items, + samples=samples, ap=self.ap, return_wav=config.return_wav if "return_wav" in config else False, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, @@ -283,6 +284,7 @@ class BaseTTS(BaseModel): min_audio_len=config.min_audio_len, max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, verbose=verbose, speaker_id_mapping=speaker_id_mapping, @@ -357,8 +359,6 @@ class BaseTTS(BaseModel): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - ap = assets["audio_processor"] - tokenizer = assets["tokenizer"] print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} @@ -370,18 +370,15 @@ class BaseTTS(BaseModel): sen, self.config, "cuda" in str(next(self.parameters()).device), - ap, - tokenizer, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, use_griffin_lim=True, do_trim_silence=False, ) test_audios["{}-audio".format(idx)] = outputs_dict["wav"] test_figures["{}-prediction".format(idx)] = plot_spectrogram( - outputs_dict["outputs"]["model_outputs"], ap, output_fig=False + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False ) test_figures["{}-alignment".format(idx)] = plot_alignment( outputs_dict["outputs"]["alignments"], output_fig=False diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index b2c41df5..699f3142 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, Tuple +from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit @@ -14,6 +14,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram @@ -170,11 +171,16 @@ class ForwardTTS(BaseTTS): """ # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): - super().__init__(config) + super().__init__(config, ap, tokenizer, speaker_manager) - self.speaker_manager = speaker_manager self.init_multispeaker(config) self.max_duration = self.args.max_duration @@ -692,19 +698,17 @@ class ForwardTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( self, config, checkpoint_path, eval=False @@ -724,3 +728,19 @@ class ForwardTTS(BaseTTS): """Enable binary alignment loss when needed""" if trainer.total_steps_done > self.config.binary_align_loss_start_step: self.use_binary_alignment_loss = True + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (ForwardTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return ForwardTTS(new_config, ap, tokenizer, speaker_manager)