Update FastPitchLoss

This commit is contained in:
Eren Gölge 2021-07-22 14:20:54 +02:00
parent b81560607b
commit fac9dbe661
1 changed files with 24 additions and 17 deletions

View File

@ -684,21 +684,28 @@ class FastPitchLoss(nn.Module):
pitch_target, pitch_target,
input_lens, input_lens,
): ):
loss = 0
return_dict = {}
if self.ssim_loss_alpha > 0:
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
loss += self.ssim_loss_alpha * ssim_loss
return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) if self.spec_loss_alpha > 0:
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
dur_loss = self.dur_loss(dur_output[:, :, None], dur_target[:, :, None], input_lens) loss += self.spec_loss_alpha * spec_loss
pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss
loss = (
self.spec_loss_alpha * spec_loss if self.dur_loss_alpha > 0:
+ self.ssim_loss_alpha * ssim_loss log_dur_tgt = torch.log(dur_target.float() + 1)
+ self.dur_loss_alpha * dur_loss dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens)
+ self.pitch_loss_alpha * pitch_loss loss += self.dur_loss_alpha * dur_loss
) return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
return {
"loss": loss, if self.pitch_loss_alpha > 0:
"loss_spec": spec_loss, pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
"loss_ssim": ssim_loss, loss += self.pitch_loss_alpha * pitch_loss
"loss_dur": dur_loss, return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
"loss_pitch": pitch_loss,
} return_dict["loss"] = loss
return return_dict