mirror of https://github.com/coqui-ai/TTS.git
fixing naming convention in vocoder losses
This commit is contained in:
parent
f18f6e6d3e
commit
0de38c2617
|
@ -87,17 +87,17 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
|
|||
class MSEGLoss(nn.Module):
|
||||
""" Mean Squared Generator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake):
|
||||
loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape))
|
||||
def forward(self, score_real):
|
||||
loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
|
||||
return loss_fake
|
||||
|
||||
|
||||
class HingeGLoss(nn.Module):
|
||||
""" Hinge Discriminator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake):
|
||||
def forward(self, score_real):
|
||||
# TODO: this might be wrong
|
||||
loss_fake = torch.mean(F.relu(1. - score_fake))
|
||||
loss_fake = torch.mean(F.relu(1. - score_real))
|
||||
return loss_fake
|
||||
|
||||
|
||||
|
@ -172,7 +172,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
|
|||
if isinstance(scores_fake, list):
|
||||
# multi-scale loss
|
||||
for score_fake, score_real in zip(scores_fake, scores_real):
|
||||
total_loss, real_loss, fake_loss = loss_func(score_fake, score_real)
|
||||
total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real)
|
||||
loss += total_loss
|
||||
real_loss += real_loss
|
||||
fake_loss += fake_loss
|
||||
|
@ -286,14 +286,14 @@ class DiscriminatorLoss(nn.Module):
|
|||
return_dict = {}
|
||||
|
||||
if self.use_mse_gan_loss:
|
||||
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.mse_loss)
|
||||
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.mse_loss)
|
||||
return_dict['D_mse_gan_loss'] = mse_D_loss
|
||||
return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
|
||||
return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
|
||||
loss += mse_D_loss
|
||||
|
||||
if self.use_hinge_gan_loss:
|
||||
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.hinge_loss)
|
||||
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake=scores_fake, scores_real=scores_real, self.hinge_loss)
|
||||
return_dict['D_hinge_gan_loss'] = hinge_D_loss
|
||||
return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
|
||||
return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss
|
||||
|
|
|
@ -29,9 +29,9 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
|
|||
plt.close()
|
||||
|
||||
figures = {
|
||||
name_prefix + "/spectrogram/fake": plot_spectrogram(spec_fake, ap),
|
||||
name_prefix + "spectrogram/real": plot_spectrogram(spec_real, ap),
|
||||
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff, ap),
|
||||
name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
|
||||
name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
|
||||
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
|
||||
name_prefix + "speech_comparison": fig_wave,
|
||||
}
|
||||
return figures
|
||||
|
|
Loading…
Reference in New Issue