From d542a5081838a8d55aec2eba7b6eee32836e052c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 20 Mar 2021 16:06:19 +0100 Subject: [PATCH] fix losses for alignTTS --- TTS/tts/layers/losses.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 518b7ff3..b506b33f 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -512,19 +512,19 @@ class AlignTTSLoss(nn.Module): if phase == 0: mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens) elif phase == 1: - spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) elif phase == 2: mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens) - spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + spec_loss = self.spec_lossX(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) elif phase == 3: - dur_loss = self.huber(dur_output, dur_target, input_lens) + dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) else: mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens) - spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - dur_loss = self.huber(dur_output, dur_target, input_lens) + dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss}