mirror of https://github.com/coqui-ai/TTS.git
Make the VITS loss function clearer
This commit is contained in:
parent
6fc3b9e679
commit
8c22d5ac49
|
@ -586,6 +586,11 @@ class VitsGeneratorLoss(nn.Module):
|
|||
l = kl / torch.sum(z_mask)
|
||||
return l
|
||||
|
||||
@staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
    """Return the negated mean cosine similarity of two embedding tensors.

    ``torch.nn.functional.cosine_similarity`` is called with its default
    arguments, so the similarity is taken along ``dim=1`` and the resulting
    values are averaged. The sign is flipped so that more similar inputs
    yield a lower (more negative) value.

    Args:
        gt_spk_emb: First embedding tensor (presumably the ground-truth
            speaker embedding, shaped ``(batch, dim)`` -- TODO confirm).
        syn_spk_emb: Second embedding tensor (presumably the embedding of
            the synthesized speech, same shape -- TODO confirm).

    Returns:
        A scalar tensor equal to ``-mean(cosine_similarity)``.
    """
    similarity = torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb)
    return -similarity.mean()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
waveform,
|
||||
|
@ -632,9 +637,7 @@ class VitsGeneratorLoss(nn.Module):
|
|||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||
|
||||
if use_speaker_encoder_as_loss:
|
||||
loss_se = (
|
||||
-torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
|
||||
)
|
||||
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
|
||||
loss += loss_se
|
||||
return_dict["loss_spk_encoder"] = loss_se
|
||||
|
||||
|
|
Loading…
Reference in New Issue