From b2e9c05acc7fc5f934d5b5785b13ed1855543207 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 16 Jun 2020 13:24:59 +0200 Subject: [PATCH] losses bug fix --- vocoder/layers/losses.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vocoder/layers/losses.py b/vocoder/layers/losses.py index fe13fa8a..d6ffe9fe 100644 --- a/vocoder/layers/losses.py +++ b/vocoder/layers/losses.py @@ -286,14 +286,20 @@ class DiscriminatorLoss(nn.Module): return_dict = {} if self.use_mse_gan_loss: - mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.mse_loss) + mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss( + scores_fake=scores_fake, + scores_real=scores_real, + loss_func=self.mse_loss) return_dict['D_mse_gan_loss'] = mse_D_loss return_dict['D_mse_gan_real_loss'] = mse_D_real_loss return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss loss += mse_D_loss if self.use_hinge_gan_loss: - hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.hinge_loss) + hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss( + scores_fake=scores_fake, + scores_real=scores_real, + loss_func=self.hinge_loss) return_dict['D_hinge_gan_loss'] = hinge_D_loss return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss