diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 518b7ff3..b506b33f 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -512,19 +512,19 @@ class AlignTTSLoss(nn.Module): if phase == 0: mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens) elif phase == 1: - spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) elif phase == 2: mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens) - spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) elif phase == 3: - dur_loss = self.huber(dur_output, dur_target, input_lens) + dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) else: mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens) - spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - dur_loss = self.huber(dur_output, dur_target, input_lens) + dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss}