From 577ec406f4912b4d1d524253b6a5294694277418 Mon Sep 17 00:00:00 2001 From: manmay nakhashi Date: Wed, 22 Jun 2022 15:37:46 +0530 Subject: [PATCH] Fix checkpointing GAN models (#1641) * checkpoint sae step crash fix * checkpoint save step crash fix * Update gan.py updated requested changes * crash fix --- TTS/vocoder/models/gan.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index ed5b26dd..a3803f77 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -185,8 +185,7 @@ class GAN(BaseVocoder): outputs = {"model_outputs": self.y_hat_g} return outputs, loss_dict - @staticmethod - def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: + def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: """Logging shared by the training and evaluation. Args: @@ -198,7 +197,7 @@ class GAN(BaseVocoder): Returns: Tuple[Dict, Dict]: log figures and audio samples. """ - y_hat = outputs[0]["model_outputs"] + y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"] y = batch["waveform"] figures = plot_results(y_hat, y, ap, name) sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()