update vocoder loss implemenatations and fix MSEDLoss

This commit is contained in:
erogol 2020-06-12 19:32:49 +02:00
parent 0b78977662
commit 3eb730acf0
1 changed files with 14 additions and 5 deletions

View File

@ -88,7 +88,7 @@ class MSEGLoss(nn.Module):
""" Mean Squared Generator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake):
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape))
return loss_fake
@ -96,7 +96,8 @@ class HingeGLoss(nn.Module):
""" Hinge Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake):
loss_fake = torch.mean(F.relu(1. + score_fake))
# TODO: this might be wrong
loss_fake = torch.mean(F.relu(1. - score_fake))
return loss_fake
@ -107,10 +108,14 @@ class HingeGLoss(nn.Module):
class MSEDLoss(nn.Module):
""" Mean Squared Discriminator Loss """
def __init__(self,):
super(MSEDLoss, self).__init__()
self.loss_func = nn.MSELoss()
# pylint: disable=no-self-use
def forward(self, score_fake, score_real):
loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape))
loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape))
loss_d = loss_real + loss_fake
return loss_d, loss_real, loss_fake
@ -126,11 +131,15 @@ class HingeDLoss(nn.Module):
class MelganFeatureLoss(nn.Module):
def __init__(self,):
super(MelganFeatureLoss, self).__init__()
self.loss_func = nn.L1Loss()
# pylint: disable=no-self-use
def forward(self, fake_feats, real_feats):
loss_feats = 0
for fake_feat, real_feat in zip(fake_feats, real_feats):
loss_feats += torch.mean(torch.abs(fake_feat - real_feat))
loss_feats += self.loss_func(fake_feats, real_feats)
loss_feats /= len(fake_feats) + len(real_feats)
return loss_feats