diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index 392447de..ff1b99e8 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -106,7 +106,6 @@ class InvConvNear(nn.Module): - x: :math:`[B, C, T]` - x_mask: :math:`[B, 1, T]` """ - b, c, t = x.size() assert c % self.num_splits == 0 if x_mask is None: diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 2e94659e..bcc46cec 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -1,4 +1,5 @@ import math +from typing import Dict, Tuple import torch from torch import nn @@ -47,7 +48,7 @@ class GlowTTS(BaseTTS): def __init__(self, config: GlowTTSConfig): - super().__init__() + super().__init__(config) # pass all config fields to `self` # for fewer code change @@ -387,7 +388,7 @@ class GlowTTS(BaseTTS): ) return outputs, loss_dict - def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use + def _create_logs(self, batch, outputs, ap): alignments = outputs["alignments"] text_input = batch["text_input"] text_lengths = batch["text_lengths"] @@ -416,15 +417,26 @@ class GlowTTS(BaseTTS): train_audio = ap.inv_melspectrogram(pred_spec.T) return figures, {"audio": train_audio} + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + ap = assets["audio_processor"] + figures, audios = self._create_logs(batch, outputs, ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, ap.sample_rate) + @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) - def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): - return self.train_log(ap, batch, outputs) + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + ap = assets["audio_processor"] + figures, audios = self._create_logs(batch, outputs, ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, ap.sample_rate) @torch.no_grad() - def test_run(self, ap): + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. @@ -432,6 +444,7 @@ class GlowTTS(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ + ap = assets["audio_processor"] print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index 5d71f4ed..29077eeb 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -1,7 +1,10 @@ import os -from TTS.trainer import Trainer, TrainingArgs, init_training +from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.glow_tts import GlowTTS +from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig( @@ -25,6 +28,24 @@ config = GlowTTSConfig( output_path=output_path, datasets=[dataset_config], ) -args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) + +# init audio processor +ap = AudioProcessor(**config.audio.to_dict()) + +# load training samples +train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) + +# init model +model = GlowTTS(config) + +# init the trainer and 🚀 +trainer = Trainer( + TrainingArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, +) trainer.fit()