fixing naming convention in vocoder losses

This commit is contained in:
erogol 2020-06-16 12:34:10 +02:00
parent f18f6e6d3e
commit 0de38c2617
2 changed files with 10 additions and 10 deletions
vocoder

View File

@ -87,17 +87,17 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
class MSEGLoss(nn.Module): class MSEGLoss(nn.Module):
""" Mean Squared Generator Loss """ """ Mean Squared Generator Loss """
# pylint: disable=no-self-use # pylint: disable=no-self-use
def forward(self, score_fake): def forward(self, score_real):
loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape)) loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
return loss_fake return loss_fake
class HingeGLoss(nn.Module): class HingeGLoss(nn.Module):
""" Hinge Discriminator Loss """ """ Hinge Discriminator Loss """
# pylint: disable=no-self-use # pylint: disable=no-self-use
def forward(self, score_fake): def forward(self, score_real):
# TODO: this might be wrong # TODO: this might be wrong
loss_fake = torch.mean(F.relu(1. - score_fake)) loss_fake = torch.mean(F.relu(1. - score_real))
return loss_fake return loss_fake
@ -172,7 +172,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
if isinstance(scores_fake, list): if isinstance(scores_fake, list):
# multi-scale loss # multi-scale loss
for score_fake, score_real in zip(scores_fake, scores_real): for score_fake, score_real in zip(scores_fake, scores_real):
total_loss, real_loss, fake_loss = loss_func(score_fake, score_real) total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real)
loss += total_loss loss += total_loss
real_loss += real_loss real_loss += real_loss
fake_loss += fake_loss fake_loss += fake_loss
@ -286,14 +286,14 @@ class DiscriminatorLoss(nn.Module):
return_dict = {} return_dict = {}
if self.use_mse_gan_loss: if self.use_mse_gan_loss:
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.mse_loss) mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.mse_loss)
return_dict['D_mse_gan_loss'] = mse_D_loss return_dict['D_mse_gan_loss'] = mse_D_loss
return_dict['D_mse_gan_real_loss'] = mse_D_real_loss return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
loss += mse_D_loss loss += mse_D_loss
if self.use_hinge_gan_loss: if self.use_hinge_gan_loss:
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.hinge_loss) hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.hinge_loss)
return_dict['D_hinge_gan_loss'] = hinge_D_loss return_dict['D_hinge_gan_loss'] = hinge_D_loss
return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss

View File

@ -29,9 +29,9 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
plt.close() plt.close()
figures = { figures = {
name_prefix + "/spectrogram/fake": plot_spectrogram(spec_fake, ap), name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
name_prefix + "spectrogram/real": plot_spectrogram(spec_real, ap), name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff, ap), name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
name_prefix + "speech_comparison": fig_wave, name_prefix + "speech_comparison": fig_wave,
} }
return figures return figures