mirror of https://github.com/coqui-ai/TTS.git
update losses to hande alingtts phases
This commit is contained in:
parent
aec0b78aff
commit
896d33ed49
|
@ -3,7 +3,6 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional
|
from torch.nn import functional
|
||||||
from torch.overrides import handle_torch_function
|
|
||||||
from TTS.tts.utils.generic_utils import sequence_mask
|
from TTS.tts.utils.generic_utils import sequence_mask
|
||||||
from TTS.tts.utils.ssim import ssim
|
from TTS.tts.utils.ssim import ssim
|
||||||
|
|
||||||
|
@ -507,10 +506,22 @@ class AlignTTSLoss(nn.Module):
|
||||||
self.l1_alpha = c.l1_alpha
|
self.l1_alpha = c.l1_alpha
|
||||||
self.mdn_alpha = c.mdn_alpha
|
self.mdn_alpha = c.mdn_alpha
|
||||||
|
|
||||||
def forward(self, mu, log_sigma, logp_max_path, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step):
|
def forward(self, mu, log_sigma, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase):
|
||||||
# flow loss - neg log likelihood
|
|
||||||
ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)
|
ssim_alpha, huber_alpha, l1_alpha, mdn_alpha = self.set_alphas(step)
|
||||||
mdn_loss = self.mdn_loss(mu, log_sigma, logp_max_path, decoder_target, input_lens, decoder_output_lens)
|
l1_loss, ssim_loss, huber_loss, mdn_loss = 0, 0, 0, 0
|
||||||
|
if phase == 0:
|
||||||
|
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
|
||||||
|
elif phase == 1:
|
||||||
|
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
|
||||||
|
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||||
|
elif phase == 2:
|
||||||
|
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
|
||||||
|
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
|
||||||
|
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||||
|
elif phase == 3:
|
||||||
|
huber_loss = self.huber(dur_output, dur_target, input_lens)
|
||||||
|
else:
|
||||||
|
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
|
||||||
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
|
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
|
||||||
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||||
huber_loss = self.huber(dur_output, dur_target, input_lens)
|
huber_loss = self.huber(dur_output, dur_target, input_lens)
|
||||||
|
|
Loading…
Reference in New Issue