mirror of https://github.com/coqui-ai/TTS.git

aligntts loss

commit aa29f5b199, parent a831468cab
@@ -3,6 +3,7 @@ import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional
+from torch.overrides import handle_torch_function
 from TTS.tts.utils.generic_utils import sequence_mask
 from TTS.tts.utils.ssim import ssim
@@ -440,5 +441,109 @@ class SpeedySpeechLoss(nn.Module):
         l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
         ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
         huber_loss = self.huber(dur_output, dur_target, input_lens)
-        loss = l1_loss + ssim_loss + huber_loss
+        loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss
         return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}
+
+
+def mse_loss_custom(input, target):
+    """MSE loss using the torch back-end without reduction.
+    It uses less VRAM than the raw code"""
+    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+    return torch._C._nn.mse_loss(expanded_input, expanded_target, 0)
+
+
+class MDNLoss(nn.Module):
+    """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf."""
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, mu, log_sigma, logp_max_path, melspec, text_lengths, mel_lengths):
+        '''
+        Shapes:
+            mu: [B, D, T]
+            log_sigma: [B, D, T]
+            melspec: [B, D, T]
+        '''
+        B, D, L = mu.size()
+        T = melspec.size(2)
+        x = melspec.transpose(1, 2).unsqueeze(1)  # [B, 1, T1, D]
+        mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
+        log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
+        exponential = -0.5 * torch.mean(mse_loss_custom(x, mu) / torch.pow(log_sigma.exp(), 2), dim=-1)  # [B, L, T]
+        log_prob_matrix = exponential - 0.5 * log_sigma.mean(dim=-1)  # -(hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))
+        log_alpha = mu.new_ones(B, L, T) * (-1e4)
+        log_alpha[:, 0, 0] = log_prob_matrix[:, 0, 0]
+        for t in range(1, T):
+            prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], (0, 0, 1, -1), value=-1e4)], dim=-1)
+            log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + log_prob_matrix[:, :, t]
+        alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1]
+        mdn_loss = -alpha_last.mean() / L
+        return mdn_loss  # , log_prob_matrix
+
+
+class AlignTTSLoss(nn.Module):
+    """Modified AlignTTS Loss.
+    Computes the following losses:
+
+        - L1 and SSIM losses from the output spectrograms.
+        - Huber loss for the duration predictor.
+        - MDNLoss for the Mixture of Density Network.
+
+    All the losses are aggregated by a weighted sum with the loss alphas.
+    Alphas can be scheduled based on the number of steps.
+
+    Args:
+        c (dict): TTS model configuration.
+    """
+    def __init__(self, c):
+        super().__init__()
+        self.mdn_loss = MDNLoss()
+        self.l1 = L1LossMasked(c['loss_masking'])
+        self.ssim = SSIMLoss()
+        self.huber = Huber()
+
+        self.ssim_alpha = c.ssim_alpha
+        self.huber_alpha = c.huber_alpha
+        self.l1_alpha = c.l1_alpha
+        self.mdn_alpha = c.mdn_alpha
+
+    def forward(self, mu, log_sigma, logp_max_path, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step):
+        # mdn loss - negative log-likelihood of the alignment
+        ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)
+        mdn_loss = self.mdn_loss(mu, log_sigma, logp_max_path, decoder_target, input_lens, decoder_output_lens)
+        l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
+        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
+        huber_loss = self.huber(dur_output, dur_target, input_lens)
+        loss = l1_alpha * l1_loss + ssim_alpha * ssim_loss + huber_alpha * huber_loss + mdn_alpha * mdn_loss
+        return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss, 'mdn_loss': mdn_loss}
+
+    def _set_alpha(self, step, alpha_settings):
+        '''Set the loss alpha w.r.t. the number of training steps.
+        Return the constant value if no schedule is set.
+
+        Example:
+            Setting an alpha schedule.
+            If ``alpha_settings`` is ``[[0, 1], [10000, 0.1]]`` then ``return_alpha == 1`` until 10k steps, then it is set to 0.1.
+            If ``alpha_settings`` is a constant value then ``return_alpha`` is set to that constant.
+
+        Args:
+            step (int): number of training steps.
+            alpha_settings (float or list): constant alpha value or a list defining the schedule as explained above.
+        '''
+        return_alpha = None
+        if isinstance(alpha_settings, list):
+            for key, alpha in alpha_settings:
+                if key < step:
+                    return_alpha = alpha
+        elif isinstance(alpha_settings, (float, int)):
+            return_alpha = alpha_settings
+        return return_alpha
+
+    def set_alphas(self, step):
+        '''Set the alpha values for all the loss functions.'''
+        ssim_alpha = self._set_alpha(step, self.ssim_alpha)
+        huber_alpha = self._set_alpha(step, self.huber_alpha)
+        l1_alpha = self._set_alpha(step, self.l1_alpha)
+        mdn_alpha = self._set_alpha(step, self.mdn_alpha)
+        return ssim_alpha, huber_alpha, l1_alpha, mdn_alpha
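
Aside, not part of the diff above: mse_loss_custom calls the C++ backend directly with reduction code 0, which is the internal enum value for reduction='none', so it skips the Python-side wrapper. A quick illustrative check against the public functional API, with toy shapes assumed only for this sketch:

import torch
from torch.nn import functional

# broadcastable toy shapes, similar to how MDNLoss calls it: [B, 1, T, D] vs [B, L, 1, D]
x = torch.randn(2, 1, 5, 4)
mu = torch.randn(2, 3, 1, 4)

xb, mub = torch.broadcast_tensors(x, mu)
out_private = torch._C._nn.mse_loss(xb, mub, 0)              # 0 == reduction 'none'
out_public = functional.mse_loss(xb, mub, reduction='none')  # public API, same values
print(torch.allclose(out_private, out_public))               # True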
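
Also illustrative only: the loop in MDNLoss.forward is a forward-algorithm recursion over mel frames, where each text token at frame t is reached either by staying on the same token or by advancing from the previous one at frame t-1. A minimal sketch of that recurrence with assumed toy sizes and a random stand-in for the per-(token, frame) log-likelihoods:

import torch
from torch.nn import functional

B, L, T = 1, 3, 5                       # 1 utterance, 3 text tokens, 5 mel frames (toy sizes)
log_prob_matrix = torch.randn(B, L, T)  # stand-in for the per-(token, frame) log-likelihoods

log_alpha = torch.full((B, L, T), -1e4)
log_alpha[:, 0, 0] = log_prob_matrix[:, 0, 0]
for t in range(1, T):
    stay = log_alpha[:, :, t - 1]                                          # keep the same token
    advance = functional.pad(log_alpha[:, :, t - 1], (1, -1), value=-1e4)  # come from the previous token
    log_alpha[:, :, t] = torch.logsumexp(torch.stack([stay, advance], dim=-1), dim=-1) + log_prob_matrix[:, :, t]

# log-likelihood of all monotonic alignments ending on the last token at the last frame
print(log_alpha[0, L - 1, T - 1])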
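
Likewise a sketch, not part of the commit: AlignTTSLoss._set_alpha resolves each loss weight either from a constant or from a [[step, value], ...] schedule, taking the value of the last entry whose step threshold has been passed. The lookup behaves roughly like this (the schedule values below are made up for the example):

def resolve_alpha(step, alpha_settings):
    # mirrors AlignTTSLoss._set_alpha: pick the last schedule entry already reached,
    # or pass a constant straight through
    if isinstance(alpha_settings, list):
        value = None
        for threshold, alpha in alpha_settings:
            if threshold < step:
                value = alpha
        return value
    return alpha_settings

print(resolve_alpha(500, [[0, 1.0], [10000, 0.1]]))    # 1.0 before 10k steps
print(resolve_alpha(20000, [[0, 1.0], [10000, 0.1]]))  # 0.1 afterwards
print(resolve_alpha(20000, 0.5))                       # constants are returned unchanged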