From bacf79f4fbb87500a43d9dbf48e466cc4d35b77a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 12:56:24 +0000 Subject: [PATCH] Update AlignTTS --- TTS/tts/models/align_tts.py | 43 ++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 2fc00b0b..c1e2ffb3 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Dict, List, Union import torch from coqpit import Coqpit @@ -12,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.io import load_fsspec @@ -100,11 +102,16 @@ class AlignTTS(BaseTTS): # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + def __init__( + self, + config: "AlignTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): - super().__init__(config) + super().__init__(config, ap, tokenizer, speaker_manager) self.speaker_manager = speaker_manager - self.config = config self.phase = -1 self.length_scale = ( float(config.model_args.length_scale) @@ -112,10 +119,6 @@ class AlignTTS(BaseTTS): else config.model_args.length_scale ) - if not self.config.model_args.num_chars: - _, self.config, num_chars = self.get_characters(config) - self.config.model_args.num_chars = num_chars - self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) self.embedded_speaker_dim = 0 @@ -382,19 +385,17 @@ class AlignTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( self, config, checkpoint_path, eval=False @@ -430,3 +431,19 @@ class AlignTTS(BaseTTS): def on_epoch_start(self, trainer): """Set AlignTTS training phase on epoch start.""" self.phase = self._set_phase(trainer.config, trainer.total_steps_done) + + @staticmethod + def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (AlignTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return AlignTTS(new_config, ap, tokenizer, speaker_manager)