mirror of https://github.com/coqui-ai/TTS.git
fixing naming convention in vocoder losses
This commit is contained in:
parent
f18f6e6d3e
commit
0de38c2617
vocoder
|
@ -87,17 +87,17 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
|
||||||
class MSEGLoss(nn.Module):
|
class MSEGLoss(nn.Module):
|
||||||
""" Mean Squared Generator Loss """
|
""" Mean Squared Generator Loss """
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
def forward(self, score_fake):
|
def forward(self, score_real):
|
||||||
loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape))
|
loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
|
||||||
return loss_fake
|
return loss_fake
|
||||||
|
|
||||||
|
|
||||||
class HingeGLoss(nn.Module):
|
class HingeGLoss(nn.Module):
|
||||||
""" Hinge Discriminator Loss """
|
""" Hinge Discriminator Loss """
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
def forward(self, score_fake):
|
def forward(self, score_real):
|
||||||
# TODO: this might be wrong
|
# TODO: this might be wrong
|
||||||
loss_fake = torch.mean(F.relu(1. - score_fake))
|
loss_fake = torch.mean(F.relu(1. - score_real))
|
||||||
return loss_fake
|
return loss_fake
|
||||||
|
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
|
||||||
if isinstance(scores_fake, list):
|
if isinstance(scores_fake, list):
|
||||||
# multi-scale loss
|
# multi-scale loss
|
||||||
for score_fake, score_real in zip(scores_fake, scores_real):
|
for score_fake, score_real in zip(scores_fake, scores_real):
|
||||||
total_loss, real_loss, fake_loss = loss_func(score_fake, score_real)
|
total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real)
|
||||||
loss += total_loss
|
loss += total_loss
|
||||||
real_loss += real_loss
|
real_loss += real_loss
|
||||||
fake_loss += fake_loss
|
fake_loss += fake_loss
|
||||||
|
@ -286,14 +286,14 @@ class DiscriminatorLoss(nn.Module):
|
||||||
return_dict = {}
|
return_dict = {}
|
||||||
|
|
||||||
if self.use_mse_gan_loss:
|
if self.use_mse_gan_loss:
|
||||||
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.mse_loss)
|
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.mse_loss)
|
||||||
return_dict['D_mse_gan_loss'] = mse_D_loss
|
return_dict['D_mse_gan_loss'] = mse_D_loss
|
||||||
return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
|
return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
|
||||||
return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
|
return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
|
||||||
loss += mse_D_loss
|
loss += mse_D_loss
|
||||||
|
|
||||||
if self.use_hinge_gan_loss:
|
if self.use_hinge_gan_loss:
|
||||||
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.hinge_loss)
|
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.hinge_loss)
|
||||||
return_dict['D_hinge_gan_loss'] = hinge_D_loss
|
return_dict['D_hinge_gan_loss'] = hinge_D_loss
|
||||||
return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
|
return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
|
||||||
return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss
|
return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss
|
||||||
|
|
|
@ -29,9 +29,9 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
figures = {
|
figures = {
|
||||||
name_prefix + "/spectrogram/fake": plot_spectrogram(spec_fake, ap),
|
name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
|
||||||
name_prefix + "spectrogram/real": plot_spectrogram(spec_real, ap),
|
name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
|
||||||
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff, ap),
|
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
|
||||||
name_prefix + "speech_comparison": fig_wave,
|
name_prefix + "speech_comparison": fig_wave,
|
||||||
}
|
}
|
||||||
return figures
|
return figures
|
||||||
|
|
Loading…
Reference in New Issue