diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 27c6e9e5..517eb533 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -462,13 +462,12 @@ class MDNLoss(nn.Module): class AlignTTSLoss(nn.Module): """Modified AlignTTS Loss. - Computes following losses + Computes - L1 and SSIM losses from output spectrograms. - Huber loss for duration predictor. - MDNLoss for Mixture of Density Network. - All the losses are aggregated by a weighted sum with the loss alphas. - Alphas can be scheduled based on number of steps. + All loss values are aggregated by a weighted sum, using the alpha values as weights. Args: c (dict): TTS model configuration. @@ -487,9 +486,9 @@ class AlignTTSLoss(nn.Module): self.mdn_alpha = c.mdn_alpha def forward( - self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase + self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, phase ): - ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step) + # ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step) spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0 if phase == 0: mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) @@ -507,36 +506,5 @@ class AlignTTSLoss(nn.Module): spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) - loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss + loss = self.spec_loss_alpha * spec_loss + self.ssim_alpha * ssim_loss + self.dur_loss_alpha * dur_loss + self.mdn_alpha * mdn_loss return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} - - @staticmethod - def _set_alpha(step, alpha_settings): - 
"""Set the loss alpha wrt number of steps. - Return the corresponding value if no schedule is set. - - Example: - Setting a alpha schedule. - if ```alpha_settings``` is ```[[0, 1], [10000, 0.1]]``` then ```return_alpha == 1``` until 10k steps, then set to 0.1. - if ```alpha_settings``` is a constant value then ```return_alpha``` is set to that constant. - - Args: - step (int): number of training steps. - alpha_settings (int or list): constant alpha value or a list defining the schedule as explained above. - """ - return_alpha = None - if isinstance(alpha_settings, list): - for key, alpha in alpha_settings: - if key < step: - return_alpha = alpha - elif isinstance(alpha_settings, (float, int)): - return_alpha = alpha_settings - return return_alpha - - def set_alphas(self, step): - """Set the alpha values for all the loss functions""" - ssim_alpha = self._set_alpha(step, self.ssim_alpha) - dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha) - spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha) - mdn_alpha = self._set_alpha(step, self.mdn_alpha) - return ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha