def forward(self, mu, log_sigma, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase):
    """Compute the weighted total loss for the given training phase.

    The ``phase`` argument selects which loss terms are evaluated:

    * ``0``   -- MDN loss only
    * ``1``   -- L1 and SSIM losses on the decoder output
    * ``2``   -- MDN, L1 and SSIM losses
    * ``3``   -- Huber loss on the durations only
    * other   -- all four terms

    A term that is not selected is left as the plain integer ``0`` (it is
    not computed), so it contributes nothing to the total and is reported
    as ``0`` in the returned dict.

    Args:
        mu, log_sigma, logp: forwarded to ``self.mdn_loss`` together with
            ``decoder_target``, ``input_lens`` and ``decoder_output_lens``.
        decoder_output, decoder_target, decoder_output_lens: forwarded to
            ``self.l1`` and ``self.ssim``.
        dur_output, dur_target, input_lens: forwarded to ``self.huber``.
        step: passed to ``self.set_alphas`` to obtain the four weights.
        phase: selector described above.

    Returns:
        dict with keys ``'loss'`` (weighted sum), ``'loss_l1'``,
        ``'loss_ssim'``, ``'loss_dur'`` and ``'mdn_loss'``.
    """
    ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)

    # Translate the phase into one switch per group of loss terms.
    with_mdn = phase not in (1, 3)
    with_spec = phase not in (0, 3)
    with_dur = phase not in (0, 1, 2)

    # Terms that are switched off stay at integer zero.
    mdn_loss = 0
    l1_loss = 0
    ssim_loss = 0
    huber_loss = 0

    if with_mdn:
        mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
    if with_spec:
        l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
    if with_dur:
        huber_loss = self.huber(dur_output, dur_target, input_lens)

    loss = l1_alpha * l1_loss + ssim_alpha * ssim_loss + huber_alpha * huber_loss + mdn_alpha * mdn_loss
    return {
        'loss': loss,
        'loss_l1': l1_loss,
        'loss_ssim': ssim_loss,
        'loss_dur': huber_loss,
        'mdn_loss': mdn_loss,
    }