mirror of https://github.com/coqui-ai/TTS.git
Update FastPitchLoss
This commit is contained in:
parent
b81560607b
commit
fac9dbe661
|
@ -684,21 +684,28 @@ class FastPitchLoss(nn.Module):
|
||||||
pitch_target,
|
pitch_target,
|
||||||
input_lens,
|
input_lens,
|
||||||
):
|
):
|
||||||
|
loss = 0
|
||||||
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
return_dict = {}
|
||||||
|
if self.ssim_loss_alpha > 0:
|
||||||
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||||
dur_loss = self.dur_loss(dur_output[:, :, None], dur_target[:, :, None], input_lens)
|
loss += self.ssim_loss_alpha * ssim_loss
|
||||||
|
return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
|
||||||
|
|
||||||
|
if self.spec_loss_alpha > 0:
|
||||||
|
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
||||||
|
loss += self.spec_loss_alpha * spec_loss
|
||||||
|
return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss
|
||||||
|
|
||||||
|
if self.dur_loss_alpha > 0:
|
||||||
|
log_dur_tgt = torch.log(dur_target.float() + 1)
|
||||||
|
dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens)
|
||||||
|
loss += self.dur_loss_alpha * dur_loss
|
||||||
|
return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
|
||||||
|
|
||||||
|
if self.pitch_loss_alpha > 0:
|
||||||
pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
|
pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
|
||||||
loss = (
|
loss += self.pitch_loss_alpha * pitch_loss
|
||||||
self.spec_loss_alpha * spec_loss
|
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
|
||||||
+ self.ssim_loss_alpha * ssim_loss
|
|
||||||
+ self.dur_loss_alpha * dur_loss
|
return_dict["loss"] = loss
|
||||||
+ self.pitch_loss_alpha * pitch_loss
|
return return_dict
|
||||||
)
|
|
||||||
return {
|
|
||||||
"loss": loss,
|
|
||||||
"loss_spec": spec_loss,
|
|
||||||
"loss_ssim": ssim_loss,
|
|
||||||
"loss_dur": dur_loss,
|
|
||||||
"loss_pitch": pitch_loss,
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue