aligntts loss

This commit is contained in:
Eren Gölge 2021-03-03 15:41:59 +01:00 committed by Eren Gölge
parent a831468cab
commit aa29f5b199
1 changed file with 106 additions and 1 deletion


@@ -3,6 +3,7 @@ import numpy as np
import torch
from torch import nn
from torch.nn import functional
from torch.overrides import handle_torch_function
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.utils.ssim import ssim
@@ -440,5 +441,109 @@ class SpeedySpeechLoss(nn.Module):
        l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
        huber_loss = self.huber(dur_output, dur_target, input_lens)
        loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss
        return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}
def mse_loss_custom(input, target):
    """MSE loss using the torch backend without reduction.
    It uses less VRAM than the plain Python implementation."""
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
    # call the ATen kernel directly; reduction code 0 means 'none'
    return torch._C._nn.mse_loss(expanded_input, expanded_target, 0)
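A quick equivalence check for the helper above. This is an illustrative sketch, not part of the commit; it assumes the file being edited is TTS/tts/layers/losses.py so that mse_loss_custom is importable. The custom call should match functional.mse_loss with reduction='none' on broadcastable inputs, since reduction code 0 maps to 'none' in the ATen backend.

import torch
from torch.nn import functional

from TTS.tts.layers.losses import mse_loss_custom  # assumed module path

x = torch.randn(2, 1, 50, 80)   # [B, 1, T, D], like the expanded mel frames
mu = torch.randn(2, 60, 1, 80)  # [B, L, 1, D], like the expanded Gaussian means

out = mse_loss_custom(x, mu)  # broadcasts to [2, 60, 50, 80], no reduction
ref = functional.mse_loss(*torch.broadcast_tensors(x, mu), reduction='none')
assert out.shape == (2, 60, 50, 80)
assert torch.allclose(out, ref)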
class MDNLoss(nn.Module):
    """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf."""

    def __init__(self):
        super().__init__()
    def forward(self, mu, log_sigma, logp_max_path, melspec, text_lengths, mel_lengths):
        '''
        Shapes:
            mu: [B, D, T]
            log_sigma: [B, D, T]
            melspec: [B, D, T]
        '''
        B, D, L = mu.size()
        T = melspec.size(2)
        x = melspec.transpose(1, 2).unsqueeze(1)  # [B, 1, T1, D]
        mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
        log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
        exponential = -0.5 * torch.mean(mse_loss_custom(x, mu) / torch.pow(log_sigma.exp(), 2), dim=-1)  # [B, L, T]
        # the constant term -(D / 2) * log(2 * pi) is dropped since it does not affect the gradient
        log_prob_matrix = exponential - 0.5 * log_sigma.mean(dim=-1)
        # forward (alpha) recursion over the monotonic alignment lattice
        log_alpha = mu.new_ones(B, L, T) * (-1e4)
        log_alpha[:, 0, 0] = log_prob_matrix[:, 0, 0]
        for t in range(1, T):
            # each text position can be reached from itself or from the previous position at t-1
            prev_step = torch.cat(
                [log_alpha[:, :, t - 1:t],
                 functional.pad(log_alpha[:, :, t - 1:t], (0, 0, 1, -1), value=-1e4)], dim=-1)
            log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + log_prob_matrix[:, :, t]
        alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1]
        mdn_loss = -alpha_last.mean() / L
        return mdn_loss  # , log_prob_matrix
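A minimal smoke test for the recursion above, with made-up shapes (not part of the commit). Note that logp_max_path is accepted but unused by forward, so None can be passed in this sketch.

import torch

B, D, L, T = 2, 80, 12, 20  # batch, mel channels, text length, mel length
mu = torch.randn(B, D, L)
log_sigma = torch.randn(B, D, L)
melspec = torch.randn(B, D, T)
text_lengths = torch.full((B,), L, dtype=torch.long)
mel_lengths = torch.full((B,), T, dtype=torch.long)

criterion = MDNLoss()
loss = criterion(mu, log_sigma, None, melspec, text_lengths, mel_lengths)
assert loss.dim() == 0  # scalar negative log-likelihood, averaged over the batch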
class AlignTTSLoss(nn.Module):
    """Modified AlignTTS Loss.
    Computes the following losses:
        - L1 and SSIM losses on the output spectrograms.
        - Huber loss for the duration predictor.
        - MDNLoss for the Mixture of Density Network.
    All the losses are aggregated by a weighted sum with the loss alphas.
    Alphas can be scheduled based on the number of training steps.
    Args:
        c (dict): TTS model configuration.
    """
    def __init__(self, c):
        super().__init__()
        self.mdn_loss = MDNLoss()
        self.l1 = L1LossMasked(c['loss_masking'])
        self.ssim = SSIMLoss()
        self.huber = Huber()

        self.ssim_alpha = c.ssim_alpha
        self.huber_alpha = c.huber_alpha
        self.l1_alpha = c.l1_alpha
        self.mdn_alpha = c.mdn_alpha
    def forward(self, mu, log_sigma, logp_max_path, decoder_output, decoder_target, decoder_output_lens, dur_output,
                dur_target, input_lens, step):
        ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)
        # MDN loss - negative log-likelihood of the alignment
        mdn_loss = self.mdn_loss(mu, log_sigma, logp_max_path, decoder_target, input_lens, decoder_output_lens)
        l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
        huber_loss = self.huber(dur_output, dur_target, input_lens)
        loss = l1_alpha * l1_loss + ssim_alpha * ssim_loss + huber_alpha * huber_loss + mdn_alpha * mdn_loss
        return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss, 'mdn_loss': mdn_loss}
    def _set_alpha(self, step, alpha_settings):
        '''Set a loss alpha w.r.t. the number of training steps.
        If no schedule is given, return the constant value unchanged.
        Example:
            Setting an alpha schedule.
            If ```alpha_settings``` is ```[[0, 1], [10000, 0.1]]``` then ```return_alpha == 1``` until 10k steps, then it is set to 0.1.
            If ```alpha_settings``` is a constant value then ```return_alpha``` is set to that constant.
        Args:
            step (int): number of training steps.
            alpha_settings (float or list): constant alpha value or a list defining the schedule as explained above.
        '''
        return_alpha = None
        if isinstance(alpha_settings, list):
            for key, alpha in alpha_settings:
                if key < step:
                    return_alpha = alpha
        elif isinstance(alpha_settings, (float, int)):
            return_alpha = alpha_settings
        return return_alpha
    def set_alphas(self, step):
        '''Set the alpha values for all the loss functions.'''
        ssim_alpha = self._set_alpha(step, self.ssim_alpha)
        huber_alpha = self._set_alpha(step, self.huber_alpha)
        l1_alpha = self._set_alpha(step, self.l1_alpha)
        mdn_alpha = self._set_alpha(step, self.mdn_alpha)
        return ssim_alpha, huber_alpha, l1_alpha, mdn_alpha
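A hypothetical end-to-end sketch of the alpha scheduling (not part of the commit). AlignTTSLoss reads loss_masking by key and the alphas by attribute, so the config stub below is a small dict subclass supporting both access styles; it assumes L1LossMasked, SSIMLoss and Huber from this module are in scope, and l1_alpha follows the schedule format from the _set_alpha docstring.

class AttrDict(dict):
    """Config stub: a dict that also allows attribute access."""
    __getattr__ = dict.__getitem__

c = AttrDict(
    loss_masking=True,
    ssim_alpha=1.0,
    huber_alpha=1.0,
    l1_alpha=[[0, 1.0], [10000, 0.1]],  # 1.0 until 10k steps, then 0.1
    mdn_alpha=1.0,
)

criterion = AlignTTSLoss(c)
# returns (ssim_alpha, huber_alpha, l1_alpha, mdn_alpha)
assert criterion.set_alphas(step=500) == (1.0, 1.0, 1.0, 1.0)
assert criterion.set_alphas(step=20000) == (1.0, 1.0, 0.1, 1.0)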