update l1 and huber to mse loss

This commit is contained in:
Eren Gölge 2021-03-19 12:49:58 +01:00
parent 896d33ed49
commit 18cc7b95ec
1 changed files with 16 additions and 16 deletions

View File

@ -497,36 +497,36 @@ class AlignTTSLoss(nn.Module):
def __init__(self, c):
super().__init__()
self.mdn_loss = MDNLoss()
self.l1 = L1LossMasked(c['loss_masking'])
self.spec_loss = MSELossMasked(False)
self.ssim = SSIMLoss()
self.huber = Huber()
self.dur_loss = MSELossMasked(False)
self.ssim_alpha = c.ssim_alpha
self.huber_alpha = c.huber_alpha
self.l1_alpha = c.l1_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.spec_loss_alpha = c.spec_loss_alpha
self.mdn_alpha = c.mdn_alpha
def forward(self, mu, log_sigma, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase):
ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)
l1_loss, ssim_loss, huber_loss, mdn_loss = 0, 0, 0, 0
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
if phase == 0:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
elif phase == 1:
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
elif phase == 2:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
elif phase == 3:
huber_loss = self.huber(dur_output, dur_target, input_lens)
dur_loss = self.huber(dur_output, dur_target, input_lens)
else:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
spec_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
huber_loss = self.huber(dur_output, dur_target, input_lens)
loss = l1_alpha * l1_loss + ssim_alpha * ssim_loss + huber_alpha * huber_loss + mdn_alpha * mdn_loss
return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss, 'mdn_loss': mdn_loss}
dur_loss = self.huber(dur_output, dur_target, input_lens)
loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss}
def _set_alpha(self, step, alpha_settings):
'''Set the loss alpha wrt number of steps.
@ -554,7 +554,7 @@ class AlignTTSLoss(nn.Module):
'''Set the alpha values for all the loss functions
'''
ssim_alpha = self._set_alpha(step, self.ssim_alpha)
huber_alpha = self._set_alpha(step, self.huber_alpha)
l1_alpha = self._set_alpha(step, self.l1_alpha)
dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha)
spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha)
mdn_alpha = self._set_alpha(step, self.mdn_alpha)
return ssim_alpha, huber_alpha, l1_alpha, mdn_alpha
return ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha