mirror of https://github.com/coqui-ai/TTS.git
Make the VITS loss function clearer
This commit is contained in:
parent 6fc3b9e679
commit 8c22d5ac49
@@ -586,6 +586,11 @@ class VitsGeneratorLoss(nn.Module):
         l = kl / torch.sum(z_mask)
         return l
 
+    @staticmethod
+    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
+        l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
+        return l
+
     def forward(
         self,
         waveform,
@@ -632,9 +637,7 @@ class VitsGeneratorLoss(nn.Module):
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
 
         if use_speaker_encoder_as_loss:
-            loss_se = (
-                -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
-            )
+            loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
             loss += loss_se
             return_dict["loss_spk_encoder"] = loss_se
 
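For reference, a minimal standalone sketch of what the extracted helper computes and how the weighted speaker-consistency term is formed. The tensor shapes and the alpha value below are assumptions for illustration, not values taken from this commit.

import torch
import torch.nn.functional as F

def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
    # Negative mean cosine similarity: minimizing this pulls the synthesized
    # speaker embedding toward the ground-truth speaker embedding.
    return -F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()

# Hypothetical batch of 4 speaker embeddings of dimension 256 (shapes assumed).
gt_spk_emb = torch.randn(4, 256)
syn_spk_emb = torch.randn(4, 256)

spk_encoder_loss_alpha = 9.0  # assumed weight; the real value comes from the model config
loss_se = cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * spk_encoder_loss_alpha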