From 3d5205d66ff84bdb82b0f18f3c47064226a43193 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Thu, 30 Sep 2021 14:21:25 +0000
Subject: [PATCH] Update WaveGrad

---
 TTS/vocoder/models/wavegrad.py              | 18 +++++++------
 recipes/ljspeech/wavegrad/train_wavegrad.py | 28 ++++++++++++++++++---
 2 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index 8d95a063..5755a9a7 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -58,7 +58,7 @@ class Wavegrad(BaseVocoder):
 
     # pylint: disable=dangerous-default-value
     def __init__(self, config: Coqpit):
-        super().__init__()
+        super().__init__(config)
         self.config = config
         self.use_weight_norm = config.model_params.use_weight_norm
         self.hop_len = np.prod(config.model_params.upsample_factors)
@@ -258,21 +258,22 @@ class Wavegrad(BaseVocoder):
         return {"model_output": noise_hat}, {"loss": loss}
 
     def train_log(  # pylint: disable=no-self-use
-        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
+        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
     ) -> Tuple[Dict, np.ndarray]:
-        return None, None
+        pass
 
     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
         return self.train_step(batch, criterion)
 
     def eval_log(  # pylint: disable=no-self-use
-        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
-    ) -> Tuple[Dict, np.ndarray]:
-        return None, None
+        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
+    ) -> None:
+        pass
 
-    def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
+    def test_run(self, assets: Dict, samples: List[Dict], outputs: Dict):  # pylint: disable=unused-argument
         # setup noise schedule and inference
+        ap = assets["audio_processor"]
         noise_schedule = self.config["test_noise_schedule"]
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
@@ -307,8 +308,9 @@ class Wavegrad(BaseVocoder):
         return {"input": m, "waveform": y}
 
     def get_data_loader(
-        self, config: Coqpit, ap: AudioProcessor, is_eval: True, data_items: List, verbose: bool, num_gpus: int
+        self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
     ):
+        ap = assets["audio_processor"]
         dataset = WaveGradDataset(
             ap=ap,
             items=data_items,
diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py
index fe038915..aa873169 100644
--- a/recipes/ljspeech/wavegrad/train_wavegrad.py
+++ b/recipes/ljspeech/wavegrad/train_wavegrad.py
@@ -1,7 +1,11 @@
 import os
 
-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import WavegradConfig
+from TTS.vocoder.models.wavegrad import Wavegrad
+from TTS.vocoder.datasets.preprocess import load_wav_data
+
 
 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavegradConfig(
@@ -22,6 +26,24 @@ config = WavegradConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = Wavegrad(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()