From 0de38c261769a1f184c906589882a8fe79951a4a Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 16 Jun 2020 12:34:10 +0200 Subject: [PATCH] fixing naming convention in vocoder losses --- vocoder/layers/losses.py | 14 +++++++------- vocoder/utils/generic_utils.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vocoder/layers/losses.py b/vocoder/layers/losses.py index 22237077..fe13fa8a 100644 --- a/vocoder/layers/losses.py +++ b/vocoder/layers/losses.py @@ -87,17 +87,17 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): class MSEGLoss(nn.Module): """ Mean Squared Generator Loss """ # pylint: disable=no-self-use - def forward(self, score_fake): - loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape)) + def forward(self, score_real): + loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape)) return loss_fake class HingeGLoss(nn.Module): """ Hinge Discriminator Loss """ # pylint: disable=no-self-use - def forward(self, score_fake): + def forward(self, score_real): # TODO: this might be wrong - loss_fake = torch.mean(F.relu(1. - score_fake)) + loss_fake = torch.mean(F.relu(1. - score_real)) return loss_fake @@ -172,7 +172,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func): if isinstance(scores_fake, list): # multi-scale loss for score_fake, score_real in zip(scores_fake, scores_real): - total_loss, real_loss, fake_loss = loss_func(score_fake, score_real) + total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real) loss += total_loss real_loss += real_loss fake_loss += fake_loss @@ -286,14 +286,14 @@ class DiscriminatorLoss(nn.Module): return_dict = {} if self.use_mse_gan_loss: - mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.mse_loss) + mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.mse_loss) return_dict['D_mse_gan_loss'] = mse_D_loss return_dict['D_mse_gan_real_loss'] = mse_D_real_loss return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss loss += mse_D_loss if self.use_hinge_gan_loss: - hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.hinge_loss) + hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.hinge_loss) return_dict['D_hinge_gan_loss'] = hinge_D_loss return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss diff --git a/vocoder/utils/generic_utils.py b/vocoder/utils/generic_utils.py index 179c4064..80c97f1a 100644 --- a/vocoder/utils/generic_utils.py +++ b/vocoder/utils/generic_utils.py @@ -29,9 +29,9 @@ def plot_results(y_hat, y, ap, global_step, name_prefix): plt.close() figures = { - name_prefix + "/spectrogram/fake": plot_spectrogram(spec_fake, ap), - name_prefix + "spectrogram/real": plot_spectrogram(spec_real, ap), - name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff, ap), + name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake), + name_prefix + "spectrogram/real": plot_spectrogram(spec_real), + name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff), name_prefix + "speech_comparison": fig_wave, } return figures