mirror of https://github.com/coqui-ai/TTS.git
update vocoder loss implemenatations and fix MSEDLoss
This commit is contained in:
parent
0b78977662
commit
3eb730acf0
|
@ -88,7 +88,7 @@ class MSEGLoss(nn.Module):
|
|||
""" Mean Squared Generator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake):
|
||||
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
|
||||
loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape))
|
||||
return loss_fake
|
||||
|
||||
|
||||
|
@ -96,7 +96,8 @@ class HingeGLoss(nn.Module):
|
|||
""" Hinge Discriminator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake):
|
||||
loss_fake = torch.mean(F.relu(1. + score_fake))
|
||||
# TODO: this might be wrong
|
||||
loss_fake = torch.mean(F.relu(1. - score_fake))
|
||||
return loss_fake
|
||||
|
||||
|
||||
|
@ -107,10 +108,14 @@ class HingeGLoss(nn.Module):
|
|||
|
||||
class MSEDLoss(nn.Module):
|
||||
""" Mean Squared Discriminator Loss """
|
||||
def __init__(self,):
|
||||
super(MSEDLoss, self).__init__()
|
||||
self.loss_func = nn.MSELoss()
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
|
||||
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
|
||||
loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape))
|
||||
loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape))
|
||||
loss_d = loss_real + loss_fake
|
||||
return loss_d, loss_real, loss_fake
|
||||
|
||||
|
@ -126,11 +131,15 @@ class HingeDLoss(nn.Module):
|
|||
|
||||
|
||||
class MelganFeatureLoss(nn.Module):
|
||||
def __init__(self,):
|
||||
super(MelganFeatureLoss, self).__init__()
|
||||
self.loss_func = nn.L1Loss()
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, fake_feats, real_feats):
|
||||
loss_feats = 0
|
||||
for fake_feat, real_feat in zip(fake_feats, real_feats):
|
||||
loss_feats += torch.mean(torch.abs(fake_feat - real_feat))
|
||||
loss_feats += self.loss_func(fake_feats, real_feats)
|
||||
loss_feats /= len(fake_feats) + len(real_feats)
|
||||
return loss_feats
|
||||
|
||||
|
|
Loading…
Reference in New Issue