diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py
index f203c533..81ba87c4 100644
--- a/TTS/vocoder/models/gan.py
+++ b/TTS/vocoder/models/gan.py
@@ -35,7 +35,7 @@ class GAN(BaseVocoder):
             >>> config = HifiganConfig()
             >>> model = GAN(config)
         """
-        super().__init__()
+        super().__init__(config)
         self.config = config
         self.model_g = setup_generator(config)
         self.model_d = setup_discriminator(config)
@@ -197,18 +197,24 @@ class GAN(BaseVocoder):
         audios = {f"{name}/audio": sample_voice}
         return figures, audios

-    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def train_log(
+        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for training."""
-        return self._log("train", ap, batch, outputs)
+        ap = assets["audio_processor"]
+        self._log("train", ap, batch, outputs)

     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Call `train_step()` with `no_grad()`"""
         return self.train_step(batch, criterion, optimizer_idx)

-    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def eval_log(
+        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for evaluation."""
-        return self._log("eval", ap, batch, outputs)
+        ap = assets["audio_processor"]
+        self._log("eval", ap, batch, outputs)

     def load_checkpoint(
         self,
@@ -299,7 +305,7 @@ class GAN(BaseVocoder):
     def get_data_loader(  # pylint: disable=no-self-use
         self,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: True,
         data_items: List,
         verbose: bool,
@@ -318,6 +324,7 @@
         Returns:
             DataLoader: Torch dataloader.
         """
+        ap = assets["audio_processor"]
         dataset = GANDataset(
             ap=ap,
             items=data_items,
diff --git a/recipes/ljspeech/hifigan/train_hifigan.py b/recipes/ljspeech/hifigan/train_hifigan.py
index f50ef476..8d1c272a 100644
--- a/recipes/ljspeech/hifigan/train_hifigan.py
+++ b/recipes/ljspeech/hifigan/train_hifigan.py
@@ -1,29 +1,51 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import HifiganConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.gan import GAN

 output_path = os.path.dirname(os.path.abspath(__file__))
+
 config = HifiganConfig(
     batch_size=32,
     eval_batch_size=16,
     num_loader_workers=4,
     num_eval_loader_workers=4,
     run_eval=True,
-    test_delay_epochs=-1,
+    test_delay_epochs=5,
     epochs=1000,
     seq_len=8192,
     pad_short=2000,
     use_noise_augment=True,
     eval_split_size=10,
     print_step=25,
-    print_eval=True,
+    print_eval=False,
     mixed_precision=False,
     lr_gen=1e-4,
     lr_disc=1e-4,
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = GAN(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
index 1473ec3c..90c52997 100644
--- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
+++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
@@ -1,29 +1,51 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import MultibandMelganConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.gan import GAN

 output_path = os.path.dirname(os.path.abspath(__file__))
+
 config = MultibandMelganConfig(
     batch_size=32,
     eval_batch_size=16,
     num_loader_workers=4,
     num_eval_loader_workers=4,
     run_eval=True,
-    test_delay_epochs=-1,
+    test_delay_epochs=5,
     epochs=1000,
     seq_len=8192,
     pad_short=2000,
     use_noise_augment=True,
     eval_split_size=10,
     print_step=25,
-    print_eval=True,
+    print_eval=False,
     mixed_precision=False,
     lr_gen=1e-4,
     lr_disc=1e-4,
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = GAN(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py
index e8979c92..a4ab93bf 100644
--- a/recipes/ljspeech/univnet/train.py
+++ b/recipes/ljspeech/univnet/train.py
@@ -1,7 +1,10 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import UnivnetConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.gan import GAN

 output_path = os.path.dirname(os.path.abspath(__file__))
 config = UnivnetConfig(
@@ -24,6 +27,24 @@ config = UnivnetConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = GAN(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
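Taken together, these changes replace the implicit init_training() setup with an explicit Trainer construction, and shared objects such as the AudioProcessor now travel through a training_assets dict that the trainer hands back to the model hooks as assets. As a minimal sketch of what a downstream model does with that dict, the hypothetical subclass below implements the new train_log signature; only the signature and the "audio_processor" key come from the diff above, and the logger.train_figures / logger.train_audios calls assume the dashboard logger exposes those methods:

    from typing import Dict

    from TTS.vocoder.models.gan import GAN


    class MyVocoder(GAN):  # hypothetical subclass, for illustration only
        def train_log(
            self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int
        ) -> None:
            # Shared training assets are looked up by key instead of being
            # passed in as a dedicated AudioProcessor argument.
            ap = assets["audio_processor"]
            figures, audios = self._log("train", ap, batch, outputs)
            # Assumed logger API: push figures and audio samples to the dashboard.
            logger.train_figures(steps, figures)
            logger.train_audios(steps, audios, ap.sample_rate)

The same pattern applies to eval_log with the corresponding eval_* logger methods.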