From d9df33f8376c135b9dc21e75f7ceeebd31a36b63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 30 Sep 2021 14:18:10 +0000 Subject: [PATCH] Update `align_tts` for trainer_v2 --- TTS/tts/models/align_tts.py | 21 ++++++++++---- recipes/ljspeech/align_tts/train_aligntts.py | 29 ++++++++++++++++++-- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 78fbaeab..3b0a848d 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -103,7 +103,7 @@ class AlignTTS(BaseTTS): def __init__(self, config: Coqpit): - super().__init__() + super().__init__(config) self.config = config self.phase = -1 self.length_scale = ( @@ -360,9 +360,7 @@ class AlignTTS(BaseTTS): return outputs, loss_dict - def train_log( - self, ap: AudioProcessor, batch: dict, outputs: dict - ) -> Tuple[Dict, Dict]: # pylint: disable=no-self-use + def _create_logs(self, batch, outputs, ap): model_outputs = outputs["model_outputs"] alignments = outputs["alignments"] mel_input = batch["mel_input"] @@ -381,11 +379,22 @@ class AlignTTS(BaseTTS): train_audio = ap.inv_melspectrogram(pred_spec.T) return figures, {"audio": train_audio} + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + ap = assets["audio_processor"] + figures, audios = self._create_logs(batch, outputs, ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, ap.sample_rate) + def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) - def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): - return self.train_log(ap, batch, outputs) + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + ap = assets["audio_processor"] + figures, audios = self._create_logs(batch, outputs, ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, ap.sample_rate) def load_checkpoint( self, config, checkpoint_path, eval=False diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py index 4e214f92..76409374 100644 --- a/recipes/ljspeech/align_tts/train_aligntts.py +++ b/recipes/ljspeech/align_tts/train_aligntts.py @@ -1,9 +1,14 @@ import os -from TTS.trainer import Trainer, TrainingArgs, init_training +from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs import AlignTTSConfig, BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.align_tts import AlignTTS +from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs dataset_config = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/") ) @@ -25,6 +30,24 @@ config = AlignTTSConfig( output_path=output_path, datasets=[dataset_config], ) -args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) + +# init audio processor +ap = AudioProcessor(**config.audio.to_dict()) + +# load training samples +train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) + +# init model +model = AlignTTS(config) + +# init the trainer and 🚀 +trainer = Trainer( + TrainingArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, +) trainer.fit()