fixing naming convention in vocoder losses

This commit is contained in:
erogol 2020-06-16 12:34:10 +02:00
parent f18f6e6d3e
commit 0de38c2617
2 changed files with 10 additions and 10 deletions

View File

@ -87,17 +87,17 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
class MSEGLoss(nn.Module):
""" Mean Squared Generator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake):
loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape))
def forward(self, score_real):
loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
return loss_fake
class HingeGLoss(nn.Module):
""" Hinge Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake):
def forward(self, score_real):
# TODO: this might be wrong
loss_fake = torch.mean(F.relu(1. - score_fake))
loss_fake = torch.mean(F.relu(1. - score_real))
return loss_fake
@ -172,7 +172,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
if isinstance(scores_fake, list):
# multi-scale loss
for score_fake, score_real in zip(scores_fake, scores_real):
total_loss, real_loss, fake_loss = loss_func(score_fake, score_real)
total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real)
loss += total_loss
real_loss += real_loss
fake_loss += fake_loss
@ -286,14 +286,14 @@ class DiscriminatorLoss(nn.Module):
return_dict = {}
if self.use_mse_gan_loss:
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.mse_loss)
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.mse_loss)
return_dict['D_mse_gan_loss'] = mse_D_loss
return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
loss += mse_D_loss
if self.use_hinge_gan_loss:
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.hinge_loss)
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.hinge_loss)
return_dict['D_hinge_gan_loss'] = hinge_D_loss
return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss

View File

@ -29,9 +29,9 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
plt.close()
figures = {
name_prefix + "/spectrogram/fake": plot_spectrogram(spec_fake, ap),
name_prefix + "spectrogram/real": plot_spectrogram(spec_real, ap),
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff, ap),
name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
name_prefix + "speech_comparison": fig_wave,
}
return figures