diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index e03cf084..09985c82 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -739,9 +739,9 @@ class ForwardTTSLoss(nn.Module): pitch_output, pitch_target, input_lens, - alignment_logprob=None, - alignment_hard=None, - alignment_soft=None, + aligner_logprob=None, + aligner_hard=None, + aligner_soft=None, binary_loss_weight=None, ): loss = 0 @@ -768,12 +768,12 @@ class ForwardTTSLoss(nn.Module): return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0: - aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) + aligner_loss = self.aligner_loss(aligner_logprob, input_lens, decoder_output_lens) loss = loss + self.aligner_loss_alpha * aligner_loss return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss - if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: - binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss if binary_loss_weight: return_dict["loss_binary_alignment"] = ( @@ -784,3 +784,83 @@ class ForwardTTSLoss(nn.Module): return_dict["loss"] = loss return return_dict + + +class ForwardTTSE2ELoss(nn.Module): + def __init__(self, config): + super().__init__() + self.encoder_loss = ForwardTTSLoss(config) + # for generator losses + self.mel_loss_alpha = ( + config.mel_loss_alpha + ) # mel_loss over the encoder model output as opposed to the vocoder output + self.feat_loss_alpha = config.feat_loss_alpha + self.gen_loss_alpha = config.gen_loss_alpha + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + def forward( + self, + decoder_output, + decoder_target, + decoder_output_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + input_lens, + aligner_logprob=None, + aligner_hard=None, + aligner_soft=None, + binary_loss_weight=None, + feats_fake=None, + feats_real=None, + scores_fake=None, + spec_slice=None, + spec_slice_hat=None, + ): + loss_dict = self.encoder_loss( + decoder_output=decoder_output, + decoder_target=decoder_target, + decoder_output_lens=decoder_output_lens, + dur_output=dur_output, + dur_target=dur_target, + pitch_output=pitch_output, + pitch_target=pitch_target, + input_lens=input_lens, + aligner_logprob=aligner_logprob, + aligner_hard=aligner_hard, + aligner_soft=aligner_soft, + binary_loss_weight=binary_loss_weight, + ) + + # vocoder generator losses + loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha + loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha + loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) + + loss_dict["vocoder_loss_mel"] = loss_mel + loss_dict["vocoder_loss_feat"] = loss_feat + loss_dict["vocoder_loss_gen"] = loss_gen + loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen + return loss_dict