From 7c4243fba7738748006ee2ac2e806812616f02a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:56:31 +0000 Subject: [PATCH] Update GlowTTS --- TTS/tts/models/glow_tts.py | 48 ++++++++++++++------------------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 73680f32..7a48b023 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -1,5 +1,5 @@ import math -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit @@ -50,8 +50,8 @@ class GlowTTS(BaseTTS): def __init__( self, config: GlowTTSConfig, - ap: "AudioProcessor", - tokenizer: "TTSTokenizer", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, ): @@ -63,7 +63,6 @@ class GlowTTS(BaseTTS): for key in config: setattr(self, key, config[key]) - self.num_chars = self.tokenizer.characters.num_chars self.decoder_output_dim = config.out_channels # init multi-speaker layers if necessary @@ -429,20 +428,18 @@ class GlowTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - 
logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: @@ -467,19 +464,16 @@ class GlowTTS(BaseTTS): sen, self.config, "cuda" in str(next(self.parameters()).device), - self.ap, - self.tokenizer, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, use_griffin_lim=True, do_trim_silence=False, ) test_audios["{}-audio".format(idx)] = outputs["wav"] test_figures["{}-prediction".format(idx)] = plot_spectrogram( - outputs["outputs"]["model_outputs"], ap, output_fig=False + outputs["outputs"]["model_outputs"], self.ap, output_fig=False ) test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) return test_figures, test_audios @@ -516,23 +510,17 @@ class GlowTTS(BaseTTS): self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps @staticmethod - def init_from_config(config: Coqpit): - """Initialize model from config.""" - - # init characters - if config.use_phonemes: - from TTS.tts.utils.text.characters import IPAPhonemes - - characters = IPAPhonemes().init_from_config(config) - else: - from TTS.tts.utils.text.characters import Graphemes - - characters = Graphemes().init_from_config(config) - config.num_chars = characters.num_chars + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + Args: + config (GlowTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ from TTS.utils.audio import AudioProcessor ap = AudioProcessor.init_from_config(config) - tokenizer = TTSTokenizer.init_from_config(config) - speaker_manager = SpeakerManager.init_from_config(config) - return GlowTTS(config, ap, tokenizer, speaker_manager) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return GlowTTS(new_config, ap, tokenizer, speaker_manager)