Fix `FastPitchLoss`

This commit is contained in:
Eren Gölge 2021-07-12 12:30:27 +02:00
parent 94e8e0d416
commit db32162eae
1 changed files with 9 additions and 9 deletions

View File

@ -685,20 +685,20 @@ class FastPitchLoss(nn.Module):
input_lens,
):
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
huber_loss = self.huber(dur_output, dur_target, input_lens)
pitch_loss = self.pitch_loss(pitch_output, pitch_target, input_lens)
dur_loss = self.dur_loss(dur_output[:, : ,None], dur_target[:, :, None], input_lens)
pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
loss = (
self.l1_alpha * l1_loss
+ self.ssim_alpha * ssim_loss
+ self.huber_alpha * huber_loss
+ self.pitch_alpha * pitch_loss
self.spec_loss_alpha * spec_loss
+ self.ssim_loss_alpha * ssim_loss
+ self.dur_loss_alpha * dur_loss
+ self.pitch_loss_alpha * pitch_loss
)
return {
"loss": loss,
"loss_l1": l1_loss,
"loss_spec": spec_loss,
"loss_ssim": ssim_loss,
"loss_dur": huber_loss,
"loss_dur": dur_loss,
"loss_pitch": pitch_loss,
}