Update VITS loss

Eren Gölge 2022-02-20 11:54:05 +01:00
parent c68962c574
commit 52a7896668
1 changed file with 18 additions and 24 deletions

@@ -587,13 +587,12 @@ class VitsGeneratorLoss(nn.Module):
     @staticmethod
     def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
-        l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
-        return l
+        return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
 
     def forward(
         self,
-        waveform,
-        waveform_hat,
+        mel_slice,
+        mel_slice_hat,
         z_p,
         logs_q,
         m_p,
@@ -609,8 +608,8 @@ class VitsGeneratorLoss(nn.Module):
     ):
         """
         Shapes:
-            - waveform : :math:`[B, 1, T]`
-            - waveform_hat: :math:`[B, 1, T]`
+            - mel_slice : :math:`[B, 1, T]`
+            - mel_slice_hat: :math:`[B, 1, T]`
             - z_p: :math:`[B, C, T]`
             - logs_q: :math:`[B, C, T]`
             - m_p: :math:`[B, C, T]`
@@ -623,30 +622,23 @@ class VitsGeneratorLoss(nn.Module):
         loss = 0.0
         return_dict = {}
         z_mask = sequence_mask(z_len).float()
-        # compute mel spectrograms from the waveforms
-        mel = self.stft(waveform)
-        mel_hat = self.stft(waveform_hat)
         # compute losses
-        loss_kl = self.kl_loss(
-            z_p=z_p,
-            logs_q=logs_q,
-            m_p=m_p,
-            logs_p=logs_p,
-            z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha
-        loss_feat = self.feature_loss(
-            feats_real=feats_disc_real,
-            feats_generated=feats_disc_fake) * self.feat_loss_alpha
+        loss_kl = (
+            self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
+            * self.kl_loss_alpha
+        )
+        loss_feat = (
+            self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
+        )
         loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
-        loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
+        loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
         if use_speaker_encoder_as_loss:
             loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
-            loss += loss_se
+            loss = loss + loss_se
             return_dict["loss_spk_encoder"] = loss_se
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
@@ -675,16 +667,18 @@ class VitsDiscriminatorLoss(nn.Module):
             loss += real_loss + fake_loss
             real_losses.append(real_loss.item())
             fake_losses.append(fake_loss.item())
         return loss, real_losses, fake_losses
 
     def forward(self, scores_disc_real, scores_disc_fake):
         loss = 0.0
         return_dict = {}
-        loss_disc, _, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
+        loss_disc, loss_disc_real, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
         return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
         loss = loss + return_dict["loss_disc"]
         return_dict["loss"] = loss
+        for i, ldr in enumerate(loss_disc_real):
+            return_dict[f"loss_disc_real_{i}"] = ldr
         return return_dict
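The discriminator-side change is pure bookkeeping: discriminator_loss already returned the per-discriminator real losses, and forward now keeps them and writes each one into the return dict under its own key so it can be logged separately. A rough standalone sketch of that flow, assuming the usual LSGAN terms used by HiFi-GAN-style discriminators; the scores here are random placeholders:

import torch

# Placeholder scores from, e.g., three sub-discriminators.
scores_disc_real = [torch.rand(2, 1) for _ in range(3)]
scores_disc_fake = [torch.rand(2, 1) for _ in range(3)]

loss = 0.0
real_losses = []
for dr, dg in zip(scores_disc_real, scores_disc_fake):
    real_loss = torch.mean((1 - dr.float()) ** 2)  # LSGAN real term
    fake_loss = torch.mean(dg.float() ** 2)        # LSGAN fake term
    loss = loss + real_loss + fake_loss
    real_losses.append(real_loss.item())

return_dict = {"loss_disc": loss, "loss": loss}
# As in this commit: expose each discriminator's real loss under its own key.
for i, ldr in enumerate(real_losses):
    return_dict[f"loss_disc_real_{i}"] = ldr

print(sorted(return_dict))  # ['loss', 'loss_disc', 'loss_disc_real_0', 'loss_disc_real_1', 'loss_disc_real_2']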