class FastPitchLoss(nn.Module):
    """Weighted sum of spectrogram, SSIM, duration and pitch losses.

    Each term is computed by its own criterion and scaled by the matching
    ``*_alpha`` weight read from the config object at construction time.
    """

    def __init__(self, c):
        """Build the per-term criteria and read their weights from ``c``.

        Args:
            c: Config object exposing ``spec_loss_alpha``, ``ssim_loss_alpha``,
                ``dur_loss_alpha`` and ``pitch_loss_alpha``.
        """
        super().__init__()
        self.spec_loss = MSELossMasked(False)
        self.ssim = SSIMLoss()
        self.dur_loss = MSELossMasked(False)
        self.pitch_loss = MSELossMasked(False)

        self.spec_loss_alpha = c.spec_loss_alpha
        self.ssim_loss_alpha = c.ssim_loss_alpha
        self.dur_loss_alpha = c.dur_loss_alpha
        self.pitch_loss_alpha = c.pitch_loss_alpha

    def forward(
        self,
        decoder_output,
        decoder_target,
        decoder_output_lens,
        dur_output,
        dur_target,
        pitch_output,
        pitch_target,
        input_lens,
    ):
        """Compute every loss term and their weighted total.

        Args:
            decoder_output: Prediction passed to the spectrogram and SSIM criteria.
            decoder_target: Target passed to the spectrogram and SSIM criteria.
            decoder_output_lens: Lengths passed alongside the decoder tensors.
            dur_output: Prediction passed to the duration criterion.
            dur_target: Target passed to the duration criterion.
            pitch_output: Prediction passed to the pitch criterion.
            pitch_target: Target passed to the pitch criterion.
            input_lens: Lengths passed alongside the duration and pitch tensors.

        Returns:
            dict with the weighted total under ``"loss"`` and the unweighted
            terms under ``"loss_l1"``, ``"loss_ssim"``, ``"loss_dur"`` and
            ``"loss_pitch"``.
        """
        # Use the criteria and weights that __init__ actually registers; the
        # previous names (self.l1, self.huber, self.*_alpha without "_loss")
        # were never defined and raised AttributeError on the first call.
        spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
        dur_loss = self.dur_loss(dur_output, dur_target, input_lens)
        pitch_loss = self.pitch_loss(pitch_output, pitch_target, input_lens)
        loss = (
            self.spec_loss_alpha * spec_loss
            + self.ssim_loss_alpha * ssim_loss
            + self.dur_loss_alpha * dur_loss
            + self.pitch_loss_alpha * pitch_loss
        )
        return {
            "loss": loss,
            # Key name kept for existing callers; the value comes from
            # self.spec_loss, which is an MSELossMasked instance.
            "loss_l1": spec_loss,
            "loss_ssim": ssim_loss,
            "loss_dur": dur_loss,
            "loss_pitch": pitch_loss,
        }