diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 94dea564..518b7ff3 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -497,36 +497,36 @@ class AlignTTSLoss(nn.Module):
     def __init__(self, c):
         super().__init__()
         self.mdn_loss = MDNLoss()
-        self.l1 = L1LossMasked(c['loss_masking'])
+        self.spec_loss = MSELossMasked(False)
         self.ssim = SSIMLoss()
-        self.huber = Huber()
+        self.dur_loss = MSELossMasked(False)
         self.ssim_alpha = c.ssim_alpha
-        self.huber_alpha = c.huber_alpha
-        self.l1_alpha = c.l1_alpha
+        self.dur_loss_alpha = c.dur_loss_alpha
+        self.spec_loss_alpha = c.spec_loss_alpha
         self.mdn_alpha = c.mdn_alpha
 
     def forward(self, mu, log_sigma, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target,
                 input_lens, step, phase):
-        ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)
-        l1_loss, ssim_loss, huber_loss, mdn_loss = 0, 0, 0, 0
+        ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
+        spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
         if phase == 0:
             mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
         elif phase == 1:
-            l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
+            spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
         elif phase == 2:
             mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
-            l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
+            spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
         elif phase == 3:
-            huber_loss = self.huber(dur_output, dur_target, input_lens)
+            dur_loss = self.dur_loss(dur_output, dur_target, input_lens)
         else:
             mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
-            l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
+            spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
-            huber_loss = self.huber(dur_output, dur_target, input_lens)
-        loss = l1_alpha * l1_loss + ssim_alpha * ssim_loss + huber_alpha * huber_loss + mdn_alpha * mdn_loss
-        return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss, 'mdn_loss': mdn_loss}
+            dur_loss = self.dur_loss(dur_output, dur_target, input_lens)
+        loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
+        return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss}
 
     def _set_alpha(self, step, alpha_settings):
         '''Set the loss alpha wrt number of steps.
@@ -554,7 +554,7 @@ class AlignTTSLoss(nn.Module):
         '''Set the alpha values for all the loss functions
         '''
         ssim_alpha = self._set_alpha(step, self.ssim_alpha)
-        huber_alpha = self._set_alpha(step, self.huber_alpha)
-        l1_alpha = self._set_alpha(step, self.l1_alpha)
+        dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha)
+        spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha)
         mdn_alpha = self._set_alpha(step, self.mdn_alpha)
-        return ssim_alpha, huber_alpha, l1_alpha, mdn_alpha
+        return ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha