Add FastPitchLoss

This commit is contained in:
Eren Gölge 2021-07-08 01:28:41 +02:00
parent fba257104d
commit c8d999b010
1 changed files with 44 additions and 0 deletions

View File

@ -658,3 +658,47 @@ class VitsDiscriminatorLoss(nn.Module):
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss
return return_dict
class FastPitchLoss(nn.Module):
def __init__(self, c):
super().__init__()
self.spec_loss = MSELossMasked(False)
self.ssim = SSIMLoss()
self.dur_loss = MSELossMasked(False)
self.pitch_loss = MSELossMasked(False)
self.spec_loss_alpha = c.spec_loss_alpha
self.ssim_loss_alpha = c.ssim_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.pitch_loss_alpha = c.pitch_loss_alpha
def forward(
self,
decoder_output,
decoder_target,
decoder_output_lens,
dur_output,
dur_target,
pitch_output,
pitch_target,
input_lens,
):
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
huber_loss = self.huber(dur_output, dur_target, input_lens)
pitch_loss = self.pitch_loss(pitch_output, pitch_target, input_lens)
loss = (
self.l1_alpha * l1_loss
+ self.ssim_alpha * ssim_loss
+ self.huber_alpha * huber_loss
+ self.pitch_alpha * pitch_loss
)
return {
"loss": loss,
"loss_l1": l1_loss,
"loss_ssim": ssim_loss,
"loss_dur": huber_loss,
"loss_pitch": pitch_loss,
}