mirror of https://github.com/coqui-ai/TTS.git
Update VITS loss
This commit is contained in:
parent
c68962c574
commit
52a7896668
|
@ -587,13 +587,12 @@ class VitsGeneratorLoss(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
|
||||
l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
|
||||
return l
|
||||
return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
waveform,
|
||||
waveform_hat,
|
||||
mel_slice,
|
||||
mel_slice_hat,
|
||||
z_p,
|
||||
logs_q,
|
||||
m_p,
|
||||
|
@ -609,8 +608,8 @@ class VitsGeneratorLoss(nn.Module):
|
|||
):
|
||||
"""
|
||||
Shapes:
|
||||
- waveform : :math:`[B, 1, T]`
|
||||
- waveform_hat: :math:`[B, 1, T]`
|
||||
- mel_slice : :math:`[B, 1, T]`
|
||||
- mel_slice_hat: :math:`[B, 1, T]`
|
||||
- z_p: :math:`[B, C, T]`
|
||||
- logs_q: :math:`[B, C, T]`
|
||||
- m_p: :math:`[B, C, T]`
|
||||
|
@ -623,30 +622,23 @@ class VitsGeneratorLoss(nn.Module):
|
|||
loss = 0.0
|
||||
return_dict = {}
|
||||
z_mask = sequence_mask(z_len).float()
|
||||
# compute mel spectrograms from the waveforms
|
||||
mel = self.stft(waveform)
|
||||
mel_hat = self.stft(waveform_hat)
|
||||
|
||||
# compute losses
|
||||
loss_kl = self.kl_loss(
|
||||
z_p=z_p,
|
||||
logs_q=logs_q,
|
||||
m_p=m_p,
|
||||
logs_p=logs_p,
|
||||
z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||
loss_feat = self.feature_loss(
|
||||
feats_real=feats_disc_real,
|
||||
feats_generated=feats_disc_fake) * self.feat_loss_alpha
|
||||
loss_kl = (
|
||||
self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
|
||||
* self.kl_loss_alpha
|
||||
)
|
||||
loss_feat = (
|
||||
self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
|
||||
)
|
||||
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
|
||||
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
||||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||
|
||||
if use_speaker_encoder_as_loss:
|
||||
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
|
||||
loss += loss_se
|
||||
loss = loss + loss_se
|
||||
return_dict["loss_spk_encoder"] = loss_se
|
||||
|
||||
# pass losses to the dict
|
||||
return_dict["loss_gen"] = loss_gen
|
||||
return_dict["loss_kl"] = loss_kl
|
||||
|
@ -675,16 +667,18 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
loss += real_loss + fake_loss
|
||||
real_losses.append(real_loss.item())
|
||||
fake_losses.append(fake_loss.item())
|
||||
|
||||
return loss, real_losses, fake_losses
|
||||
|
||||
def forward(self, scores_disc_real, scores_disc_fake):
    """Compute the total discriminator loss and per-discriminator real losses.

    Args:
        scores_disc_real: discriminator scores for real samples (one entry
            per sub-discriminator).
        scores_disc_fake: discriminator scores for generated samples.

    Returns:
        dict with:
            - "loss_disc": weighted discriminator loss.
            - "loss": total loss (currently identical to "loss_disc").
            - "loss_disc_real_{i}": per-sub-discriminator real-sample loss,
              useful for monitoring which discriminator dominates training.
    """
    loss = 0.0
    return_dict = {}
    # discriminator_loss returns (total, per-disc real losses, per-disc fake
    # losses); the fake components are already folded into the total.
    loss_disc, loss_disc_real, _ = self.discriminator_loss(
        scores_real=scores_disc_real, scores_fake=scores_disc_fake
    )
    return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
    loss = loss + return_dict["loss_disc"]
    return_dict["loss"] = loss
    # Expose each sub-discriminator's real loss for logging.
    for i, ldr in enumerate(loss_disc_real):
        return_dict[f"loss_disc_real_{i}"] = ldr
    return return_dict
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue