fix loss computation when loss masking is not enabled

erogol 2020-10-19 15:47:12 +02:00
parent e8294cb9db
commit 8de7c13708
1 changed file with 11 additions and 3 deletions


@@ -2,6 +2,7 @@ import math
 import numpy as np
 import torch
 from torch import nn
+from inspect import signature
 from torch.nn import functional
 from TTS.tts.utils.generic_utils import sequence_mask
@@ -142,7 +143,11 @@ class DifferentailSpectralLoss(nn.Module):
     def forward(self, x, target, length):
         x_diff = x[:, 1:] - x[:, :-1]
         target_diff = target[:, 1:] - target[:, :-1]
-        return self.loss_func(x_diff, target_diff, length-1)
+        if len(signature(self.loss_func).parameters) > 2:
+            return self.loss_func(x_diff, target_diff, length-1)
+        else:
+            # if loss masking is not enabled
+            return self.loss_func(x_diff, target_diff)
 class GuidedAttentionLoss(torch.nn.Module):
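The hunk above dispatches on the wrapped loss function's arity: `inspect.signature` counts the parameters of `self.loss_func` and forwards the sequence lengths only when the callable accepts a third (length) argument, so unmasked losses keep working unchanged. A minimal standalone sketch of that pattern, assuming plain-function losses (the `plain_l1`/`masked_l1`/`spectral_diff_loss` helpers are hypothetical stand-ins, not from the commit):

from inspect import signature

import torch
import torch.nn.functional as F


def plain_l1(x, target):
    # unmasked loss: averages over every frame, padding included
    return F.l1_loss(x, target)


def masked_l1(x, target, length):
    # masked loss: averages only over the valid frames given by `length`
    mask = torch.arange(x.size(1), device=x.device)[None, :] < length[:, None]
    mask = mask.unsqueeze(-1).expand_as(x).float()
    return F.l1_loss(x * mask, target * mask, reduction="sum") / mask.sum()


def spectral_diff_loss(loss_func, x, target, length):
    # first-order time difference of the spectrogram, as in the hunk above
    x_diff = x[:, 1:] - x[:, :-1]
    target_diff = target[:, 1:] - target[:, :-1]
    if len(signature(loss_func).parameters) > 2:
        # three parameters -> the loss expects lengths, i.e. masking is on;
        # the diff drops one frame, so the valid lengths shrink by one
        return loss_func(x_diff, target_diff, length - 1)
    # two parameters -> plain, unmasked loss
    return loss_func(x_diff, target_diff)


# usage: both callables go through the same entry point
x, target = torch.rand(2, 10, 80), torch.rand(2, 10, 80)
length = torch.tensor([10, 7])
print(spectral_diff_loss(plain_l1, x, target, length))
print(spectral_diff_loss(masked_l1, x, target, length))

Note the `length - 1` in the masked branch: differencing along time removes one frame, so the valid lengths must shrink accordingly. Signature inspection like this is reliable for plain functions; for `nn.Module` criteria with `*args` signatures the parameter count can mislead, which is presumably why the second hunk below switches on the config flag instead.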
@@ -262,8 +267,11 @@ class TacotronLoss(torch.nn.Module):
         # double decoder consistency loss (if enabled)
         if self.config.double_decoder_consistency:
-            decoder_b_loss = self.criterion(decoder_b_output, mel_input,
-                                            output_lens)
+            if self.config.loss_masking:
+                decoder_b_loss = self.criterion(decoder_b_output, mel_input,
+                                                output_lens)
+            else:
+                decoder_b_loss = self.criterion(decoder_b_output, mel_input)
             # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
             attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
             loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss)
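This second hunk applies the same guard explicitly at the config level: the backward-decoder criterion receives `output_lens` only when `loss_masking` is enabled, instead of always passing the lengths and crashing on unmasked criteria. A hedged sketch of that call site as a free function, with the argument names mirroring the diff (in the repo this logic lives inline in `TacotronLoss.forward`; the helper itself is hypothetical):

import torch
import torch.nn.functional as F


def double_decoder_consistency_loss(config, criterion, decoder_b_output,
                                    mel_input, output_lens,
                                    alignments, alignments_backwards):
    # forward the lengths only when the criterion is a masked loss
    if config.loss_masking:
        decoder_b_loss = criterion(decoder_b_output, mel_input, output_lens)
    else:
        decoder_b_loss = criterion(decoder_b_output, mel_input)
    # consistency between forward and backward decoder alignments
    attention_c_loss = F.l1_loss(alignments, alignments_backwards)
    return decoder_b_loss + attention_c_loss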