diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 213970a7..7e9304de 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -3,6 +3,7 @@ import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional
+from torch.overrides import handle_torch_function
 
 from TTS.tts.utils.generic_utils import sequence_mask
 from TTS.tts.utils.ssim import ssim
@@ -440,5 +441,109 @@ class SpeedySpeechLoss(nn.Module):
         l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
         ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
         huber_loss = self.huber(dur_output, dur_target, input_lens)
-        loss = l1_loss + ssim_loss + huber_loss
+        loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss
         return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}
+
+
+def mse_loss_custom(input, target):
+    """MSE loss using the torch back-end without reduction.
+    It uses less VRAM than the stock ``functional.mse_loss``."""
+    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+    # reduction enum 0 == 'none'
+    return torch._C._nn.mse_loss(expanded_input, expanded_target, 0)
+
+
+class MDNLoss(nn.Module):
+    """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf."""
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, mu, log_sigma, logp_max_path, melspec, text_lengths, mel_lengths):
+        '''
+        Shapes:
+            mu: [B, D, L]
+            log_sigma: [B, D, L]
+            melspec: [B, D, T]
+        '''
+        B, D, L = mu.size()
+        T = melspec.size(2)
+        x = melspec.transpose(1, 2).unsqueeze(1)  # [B, 1, T, D]
+        mu = mu.transpose(1, 2).unsqueeze(2)  # [B, L, 1, D]
+        log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, L, 1, D]
+        exponential = -0.5 * torch.mean(mse_loss_custom(x, mu) / torch.pow(log_sigma.exp(), 2), dim=-1)  # [B, L, T]
+        log_prob_matrix = exponential - 0.5 * log_sigma.mean(dim=-1)  # -(hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))
+        log_alpha = mu.new_ones(B, L, T) * (-1e4)
+        log_alpha[:, 0, 0] = log_prob_matrix[:, 0, 0]
+        for t in range(1, T):
+            # combine the "stay on the same token" and "advance one token" paths
+            prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], (0, 0, 1, -1), value=-1e4)], dim=-1)
+            log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + log_prob_matrix[:, :, t]
+        alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1]
+        mdn_loss = -alpha_last.mean() / L
+        return mdn_loss  # , log_prob_matrix
+
+
+class AlignTTSLoss(nn.Module):
+    """Modified AlignTTS Loss.
+    Computes the following losses:
+
+        - L1 and SSIM losses on the output spectrograms.
+        - Huber loss for the duration predictor.
+        - MDNLoss for the Mixture of Density Network.
+
+    All the losses are aggregated by a weighted sum with the loss alphas.
+    Alphas can be scheduled based on the number of training steps.
+
+    Args:
+        c (dict): TTS model configuration.
+    """
+    def __init__(self, c):
+        super().__init__()
+        self.mdn_loss = MDNLoss()
+        self.l1 = L1LossMasked(c['loss_masking'])
+        self.ssim = SSIMLoss()
+        self.huber = Huber()
+
+        self.ssim_alpha = c.ssim_alpha
+        self.huber_alpha = c.huber_alpha
+        self.l1_alpha = c.l1_alpha
+        self.mdn_alpha = c.mdn_alpha
+
+    def forward(self, mu, log_sigma, logp_max_path, decoder_output, decoder_target, decoder_output_lens,
+                dur_output, dur_target, input_lens, step):
+        ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)
+        # MDN loss - negative log-likelihood of the alignment
+        mdn_loss = self.mdn_loss(mu, log_sigma, logp_max_path, decoder_target, input_lens, decoder_output_lens)
+        l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
+        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
+        huber_loss = self.huber(dur_output, dur_target, input_lens)
+        loss = l1_alpha * l1_loss + ssim_alpha * ssim_loss + huber_alpha * huber_loss + mdn_alpha * mdn_loss
+        return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss, 'mdn_loss': mdn_loss}
+
+    def _set_alpha(self, step, alpha_settings):
+        '''Set a loss alpha w.r.t. the number of training steps.
+        Return the constant value if no schedule is given.
+
+        Example:
+            Setting an alpha schedule.
+            If ```alpha_settings``` is ```[[0, 1], [10000, 0.1]]``` then ```return_alpha == 1``` until step 10k, after which it is set to 0.1.
+            If ```alpha_settings``` is a constant value then ```return_alpha``` is set to that constant.
+
+        Args:
+            step (int): number of training steps.
+            alpha_settings (float or list): constant alpha value or a list defining the schedule as explained above.
+        '''
+        return_alpha = None
+        if isinstance(alpha_settings, list):
+            for key, alpha in alpha_settings:
+                if key < step:
+                    return_alpha = alpha
+        elif isinstance(alpha_settings, (float, int)):
+            return_alpha = alpha_settings
+        return return_alpha
+
+    def set_alphas(self, step):
+        '''Set the alpha values for all the loss functions.'''
+        ssim_alpha = self._set_alpha(step, self.ssim_alpha)
+        huber_alpha = self._set_alpha(step, self.huber_alpha)
+        l1_alpha = self._set_alpha(step, self.l1_alpha)
+        mdn_alpha = self._set_alpha(step, self.mdn_alpha)
+        return ssim_alpha, huber_alpha, l1_alpha, mdn_alpha