Turn more clear the VITS loss function

This commit is contained in:
Edresson 2021-11-22 08:48:56 -03:00 committed by Eren Gölge
parent 5fc127bb7a
commit 86b2536491
1 changed files with 6 additions and 3 deletions

View File

@ -586,6 +586,11 @@ class VitsGeneratorLoss(nn.Module):
l = kl / torch.sum(z_mask)
return l
@staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
return l
def forward(
self,
waveform,
@ -632,9 +637,7 @@ class VitsGeneratorLoss(nn.Module):
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
if use_speaker_encoder_as_loss:
loss_se = (
-torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
)
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
loss += loss_se
return_dict["loss_spk_encoder"] = loss_se