From db32162eaead6ae5f6ee9f34bafbdfebe83acb44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 12 Jul 2021 12:30:27 +0200 Subject: [PATCH] Fix `FastPitchLoss` --- TTS/tts/layers/losses.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index b4866df1..8a50c811 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -685,20 +685,20 @@ class FastPitchLoss(nn.Module): input_lens, ): - l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - huber_loss = self.huber(dur_output, dur_target, input_lens) - pitch_loss = self.pitch_loss(pitch_output, pitch_target, input_lens) + dur_loss = self.dur_loss(dur_output[:, : ,None], dur_target[:, :, None], input_lens) + pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) loss = ( - self.l1_alpha * l1_loss - + self.ssim_alpha * ssim_loss - + self.huber_alpha * huber_loss - + self.pitch_alpha * pitch_loss + self.spec_loss_alpha * spec_loss + + self.ssim_loss_alpha * ssim_loss + + self.dur_loss_alpha * dur_loss + + self.pitch_loss_alpha * pitch_loss ) return { "loss": loss, - "loss_l1": l1_loss, + "loss_spec": spec_loss, "loss_ssim": ssim_loss, - "loss_dur": huber_loss, + "loss_dur": dur_loss, "loss_pitch": pitch_loss, }