diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index ed5b26dd..a3803f77 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -185,8 +185,7 @@ class GAN(BaseVocoder): outputs = {"model_outputs": self.y_hat_g} return outputs, loss_dict - @staticmethod - def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: + def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: """Logging shared by the training and evaluation. Args: @@ -198,7 +197,7 @@ class GAN(BaseVocoder): Returns: Tuple[Dict, Dict]: log figures and audio samples. """ - y_hat = outputs[0]["model_outputs"] + y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"] y = batch["waveform"] figures = plot_results(y_hat, y, ap, name) sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()