From fac9dbe6619f87444a540b943b650ce96b7fab6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:20:54 +0200 Subject: [PATCH] Update FastPitchLoss --- TTS/tts/layers/losses.py | 41 +++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index efe64e2b..fefaec9a 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -684,21 +684,28 @@ class FastPitchLoss(nn.Module): pitch_target, input_lens, ): + loss = 0 + return_dict = {} + if self.ssim_loss_alpha > 0: + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + loss += self.ssim_loss_alpha * ssim_loss + return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss - spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) - ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - dur_loss = self.dur_loss(dur_output[:, :, None], dur_target[:, :, None], input_lens) - pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) - loss = ( - self.spec_loss_alpha * spec_loss - + self.ssim_loss_alpha * ssim_loss - + self.dur_loss_alpha * dur_loss - + self.pitch_loss_alpha * pitch_loss - ) - return { - "loss": loss, - "loss_spec": spec_loss, - "loss_ssim": ssim_loss, - "loss_dur": dur_loss, - "loss_pitch": pitch_loss, - } + if self.spec_loss_alpha > 0: + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + loss += self.spec_loss_alpha * spec_loss + return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss + + if self.dur_loss_alpha > 0: + log_dur_tgt = torch.log(dur_target.float() + 1) + dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens) + loss += self.dur_loss_alpha * dur_loss + return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss + + if self.pitch_loss_alpha > 0: + pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) + loss += self.pitch_loss_alpha * pitch_loss + return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + + return_dict["loss"] = loss + return return_dict