diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index f5d0a33e..367efdc2 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -89,14 +89,14 @@ class GAN(BaseVocoder): if optimizer_idx not in [0, 1]: raise ValueError(" [!] Unexpected `optimizer_idx`.") - if optimizer_idx == 0: # DISCRIMINATOR optimization # generator pass y_hat = self.model_g(x)[:, :, : y.size(2)] - + # cache for generator loss + # pylint: disable=W0201 self.y_hat_g = y_hat self.y_hat_sub = None self.y_sub_g = None @@ -178,7 +178,9 @@ class GAN(BaseVocoder): feats_fake, feats_real = None, None # compute losses - loss_dict = criterion[optimizer_idx](self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g) + loss_dict = criterion[optimizer_idx]( + self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g + ) outputs = {"model_outputs": self.y_hat_g} return outputs, loss_dict