diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index acf750a0..9c219998 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -586,6 +586,11 @@ class VitsGeneratorLoss(nn.Module): l = kl / torch.sum(z_mask) return l + @staticmethod + def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): + l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() + return l + def forward( self, waveform, @@ -632,9 +637,7 @@ class VitsGeneratorLoss(nn.Module): loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration if use_speaker_encoder_as_loss: - loss_se = ( - -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha - ) + loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha loss += loss_se return_dict["loss_spk_encoder"] = loss_se