From 45889804c2e314bdfeccb7e405aefabf4b17424f Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Thu, 30 Sep 2021 14:22:11 +0000
Subject: [PATCH] Update VITS

---
 TTS/tts/models/vits.py                  | 43 ++++++++++++++-----------
 recipes/ljspeech/vits_tts/train_vits.py | 29 +++++++++++++++--
 2 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 87695774..0ede3d13 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -217,7 +217,7 @@ class Vits(BaseTTS):
 
     def __init__(self, config: Coqpit):
 
-        super().__init__()
+        super().__init__(config)
 
         self.END2END = True
 
@@ -576,22 +576,7 @@ class Vits(BaseTTS):
             )
         return outputs, loss_dict
 
-    def train_log(
-        self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
-    ):  # pylint: disable=no-self-use
-        """Create visualizations and waveform examples.
-
-        For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to
-        be projected onto Tensorboard.
-
-        Args:
-            ap (AudioProcessor): audio processor used at training.
-            batch (Dict): Model inputs used at the previous training step.
-            outputs (Dict): Model outputs generated at the previoud training step.
-
-        Returns:
-            Tuple[Dict, np.ndarray]: training plots and output waveform.
-        """
+    def _log(self, ap, batch, outputs, name_prefix="train"):
         y_hat = outputs[0]["model_outputs"]
         y = outputs[0]["waveform_seg"]
         figures = plot_results(y_hat, y, ap, name_prefix)
@@ -609,12 +594,32 @@ class Vits(BaseTTS):
 
         return figures, audios
 
+    def train_log(
+        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
+    ):  # pylint: disable=no-self-use
+        """Create visualizations and waveform examples.
+
+        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
+        be projected onto Tensorboard.
+
+        Args:
+            ap (AudioProcessor): audio processor used at training.
+            batch (Dict): Model inputs used at the previous training step.
+            outputs (Dict): Model outputs generated at the previous training step.
+
+        Returns:
+            Tuple[Dict, np.ndarray]: training plots and output waveform.
+        """
+        ap = assets["audio_processor"]
+        self._log(ap, batch, outputs, "train")
+
     @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         return self.train_step(batch, criterion, optimizer_idx)
 
-    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
-        return self.train_log(ap, batch, outputs, "eval")
+    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
+        ap = assets["audio_processor"]
+        return self._log(ap, batch, outputs, "eval")
 
     @torch.no_grad()
     def test_run(self, ap) -> Tuple[Dict, Dict]:
diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py
index 7cf52f89..3a2b1ef1 100644
--- a/recipes/ljspeech/vits_tts/train_vits.py
+++ b/recipes/ljspeech/vits_tts/train_vits.py
@@ -1,8 +1,12 @@
 import os
 
 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
 from TTS.tts.configs import BaseDatasetConfig, VitsConfig
+from TTS.tts.models.vits import Vits
+from TTS.utils.audio import AudioProcessor
+from TTS.tts.datasets import load_tts_samples
+
 
 output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
@@ -24,6 +28,7 @@ audio_config = BaseAudioConfig(
     signal_norm=False,
     do_amp_to_db_linear=False,
 )
+
 config = VitsConfig(
     audio=audio_config,
     run_name="vits_ljspeech",
@@ -47,6 +52,24 @@ config = VitsConfig(
     output_path=output_path,
     datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+
+# init model
+model = Vits(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()