fix loss computation when loss masking is disabled

erogol 2020-10-19 15:47:12 +02:00
parent e8294cb9db
commit 8de7c13708
1 changed file with 11 additions and 3 deletions


@@ -2,6 +2,7 @@ import math
 import numpy as np
 import torch
 from torch import nn
+from inspect import signature
 from torch.nn import functional
 from TTS.tts.utils.generic_utils import sequence_mask
@@ -142,7 +143,11 @@ class DifferentailSpectralLoss(nn.Module):
     def forward(self, x, target, length):
         x_diff = x[:, 1:] - x[:, :-1]
         target_diff = target[:, 1:] - target[:, :-1]
-        return self.loss_func(x_diff, target_diff, length-1)
+        if len(signature(self.loss_func).parameters) > 2:
+            return self.loss_func(x_diff, target_diff, length-1)
+        else:
+            # if loss masking is not enabled
+            return self.loss_func(x_diff, target_diff)
 class GuidedAttentionLoss(torch.nn.Module):
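
For context, a minimal standalone sketch of the dispatch this hunk introduces. The loss functions below are hypothetical stand-ins, not the repo's masked-loss modules: inspect.signature counts the callable's declared parameters, so a three-argument masked loss receives the trimmed lengths while a plain two-argument loss is called without them.

import torch
from inspect import signature

# Hypothetical stand-ins: a masked loss that needs per-sample lengths
# and a plain loss that does not. Both operate on (B, T, D) tensors.
def masked_l1(x, target, length):
    # Build a (B, T, 1) mask from the lengths and average the absolute
    # error over the valid positions only.
    mask = (torch.arange(x.shape[1])[None, :] < length[:, None]).unsqueeze(2).float()
    return ((x - target).abs() * mask).sum() / (mask.sum() * x.shape[2])

def plain_l1(x, target):
    return (x - target).abs().mean()

def differential_spectral_loss(loss_func, x, target, length):
    # Same dispatch as the hunk above: forward the (shortened) lengths
    # only if the loss callable declares a third parameter.
    x_diff = x[:, 1:] - x[:, :-1]
    target_diff = target[:, 1:] - target[:, :-1]
    if len(signature(loss_func).parameters) > 2:
        return loss_func(x_diff, target_diff, length - 1)
    return loss_func(x_diff, target_diff)

x = torch.rand(2, 10, 80)
target = torch.rand(2, 10, 80)
length = torch.tensor([10, 7])
print(differential_spectral_loss(masked_l1, x, target, length))  # masked path
print(differential_spectral_loss(plain_l1, x, target, length))   # unmasked path
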
@@ -262,8 +267,11 @@ class TacotronLoss(torch.nn.Module):
         # double decoder consistency loss (if enabled)
         if self.config.double_decoder_consistency:
-            decoder_b_loss = self.criterion(decoder_b_output, mel_input,
-                                            output_lens)
+            if self.config.loss_masking:
+                decoder_b_loss = self.criterion(decoder_b_output, mel_input,
+                                                output_lens)
+            else:
+                decoder_b_loss = self.criterion(decoder_b_output, mel_input)
             # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
             attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
             loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss)
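
The same fix in isolation: with loss_masking disabled the criterion is a plain two-argument loss, so output_lens must not be forwarded to it. A rough sketch using a hypothetical Config stand-in rather than the actual TacotronLoss setup:

import torch
from torch import nn

# Hypothetical minimal config standing in for self.config in TacotronLoss.
class Config:
    loss_masking = False
    double_decoder_consistency = True

config = Config()
criterion = nn.L1Loss()  # the unmasked criterion used when loss_masking is off

decoder_b_output = torch.rand(2, 50, 80)
mel_input = torch.rand(2, 50, 80)
output_lens = torch.tensor([50, 42])

if config.double_decoder_consistency:
    if config.loss_masking:
        # masked criteria accept the output lengths as a third argument
        decoder_b_loss = criterion(decoder_b_output, mel_input, output_lens)
    else:
        # plain criteria take only (input, target); passing output_lens here
        # raised a TypeError before this commit
        decoder_b_loss = criterion(decoder_b_output, mel_input)
    print(decoder_b_loss)
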