update losses to handle AlignTTS phases

This commit is contained in:
Eren Gölge 2021-03-17 17:19:08 +01:00
parent aec0b78aff
commit 896d33ed49
1 changed file with 18 additions and 7 deletions

View File

@@ -3,7 +3,6 @@ import numpy as np
import torch
from torch import nn
from torch.nn import functional
from torch.overrides import handle_torch_function
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.utils.ssim import ssim
@@ -507,13 +506,25 @@ class AlignTTSLoss(nn.Module):
self.l1_alpha = c.l1_alpha
self.mdn_alpha = c.mdn_alpha
def forward(self, mu, log_sigma, logp_max_path, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step):
# flow loss - neg log likelihood
def forward(self, mu, log_sigma, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase):
ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)
mdn_loss = self.mdn_loss(mu, log_sigma, logp_max_path, decoder_target, input_lens, decoder_output_lens)
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
huber_loss = self.huber(dur_output, dur_target, input_lens)
l1_loss, ssim_loss, huber_loss, mdn_loss = 0, 0, 0, 0
if phase == 0:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
elif phase == 1:
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
elif phase == 2:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
elif phase == 3:
huber_loss = self.huber(dur_output, dur_target, input_lens)
else:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
huber_loss = self.huber(dur_output, dur_target, input_lens)
loss = l1_alpha * l1_loss + ssim_alpha * ssim_loss + huber_alpha * huber_loss + mdn_alpha * mdn_loss
return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss, 'mdn_loss': mdn_loss}