# Provenance: patch c8d999b0105281b08b4f6e14b7edea6565310d75 ("Add FastPitchLoss"),
# which appends this class to TTS/tts/layers/losses.py after VitsDiscriminatorLoss.
# The patch adds no imports, so `nn`, `MSELossMasked` and `SSIMLoss` are presumably
# already in scope in that module.


class FastPitchLoss(nn.Module):
    """Weighted sum of spectrogram, SSIM, duration and pitch losses.

    Each term is produced by its own criterion and scaled by a coefficient
    read from the config object handed to the constructor.
    """

    def __init__(self, c):
        """Build the four criteria and read their weights from ``c``.

        Args:
            c: Config object exposing ``spec_loss_alpha``, ``ssim_loss_alpha``,
                ``dur_loss_alpha`` and ``pitch_loss_alpha`` attributes.
        """
        super().__init__()
        # NOTE(review): the positional ``False`` is forwarded to MSELossMasked
        # unchanged; its meaning is defined by that class, not visible here.
        self.spec_loss = MSELossMasked(False)
        self.ssim = SSIMLoss()
        self.dur_loss = MSELossMasked(False)
        self.pitch_loss = MSELossMasked(False)

        self.spec_loss_alpha = c.spec_loss_alpha
        self.ssim_loss_alpha = c.ssim_loss_alpha
        self.dur_loss_alpha = c.dur_loss_alpha
        self.pitch_loss_alpha = c.pitch_loss_alpha

    def forward(
        self,
        decoder_output,
        decoder_target,
        decoder_output_lens,
        dur_output,
        dur_target,
        pitch_output,
        pitch_target,
        input_lens,
    ):
        """Compute the individual loss terms and their weighted total.

        Args:
            decoder_output: Decoder prediction, scored against ``decoder_target``.
            decoder_target: Reference for the decoder prediction.
            decoder_output_lens: Lengths passed to the spectrogram and SSIM criteria.
            dur_output: Duration prediction, scored against ``dur_target``.
            dur_target: Reference durations.
            pitch_output: Pitch prediction, scored against ``pitch_target``.
            pitch_target: Reference pitch.
            input_lens: Lengths passed to the duration and pitch criteria.

        Returns:
            dict with the weighted total under ``"loss"`` and the unweighted
            terms under ``"loss_l1"``, ``"loss_ssim"``, ``"loss_dur"`` and
            ``"loss_pitch"``.
        """
        # The original body read self.l1 / self.huber / self.*_alpha names that
        # __init__ never creates; use the attributes that actually exist.
        spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
        dur_loss = self.dur_loss(dur_output, dur_target, input_lens)
        pitch_loss = self.pitch_loss(pitch_output, pitch_target, input_lens)

        loss = (
            self.spec_loss_alpha * spec_loss
            + self.ssim_loss_alpha * ssim_loss
            + self.dur_loss_alpha * dur_loss
            + self.pitch_loss_alpha * pitch_loss
        )
        return {
            "loss": loss,
            # Key names are kept from the original for callers/loggers, even
            # though the spectrogram criterion here is MSELossMasked.
            "loss_l1": spec_loss,
            "loss_ssim": ssim_loss,
            "loss_dur": dur_loss,
            "loss_pitch": pitch_loss,
        }