diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index f07851ac..8256c0f7 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -2,6 +2,7 @@ import math
 import numpy as np
 import torch
 from torch import nn
+from inspect import signature
 from torch.nn import functional
 from TTS.tts.utils.generic_utils import sequence_mask
 
@@ -142,7 +143,16 @@ class DifferentailSpectralLoss(nn.Module):
     def forward(self, x, target, length):
         x_diff = x[:, 1:] - x[:, :-1]
         target_diff = target[:, 1:] - target[:, :-1]
-        return self.loss_func(x_diff, target_diff, length-1)
+        # Inspect ``forward`` when the loss is an ``nn.Module``: the signature of
+        # the module object itself is ``Module.__call__(*args, **kwargs)``, which
+        # reports two parameters whether or not the loss takes a length argument.
+        loss_forward = getattr(self.loss_func, "forward", self.loss_func)
+        if len(signature(loss_forward).parameters) > 2:
+            # masked loss: differencing drops one step, so shorten the lengths
+            return self.loss_func(x_diff, target_diff, length-1)
+        # loss masking is not enabled: the loss takes only (input, target)
+        return self.loss_func(x_diff, target_diff)
 
 
 class GuidedAttentionLoss(torch.nn.Module):
@@ -262,8 +272,11 @@ class TacotronLoss(torch.nn.Module):
 
         # double decoder consistency loss (if enabled)
         if self.config.double_decoder_consistency:
-            decoder_b_loss = self.criterion(decoder_b_output, mel_input,
-                                            output_lens)
+            if self.config.loss_masking:
+                decoder_b_loss = self.criterion(decoder_b_output, mel_input,
+                                                output_lens)
+            else:
+                decoder_b_loss = self.criterion(decoder_b_output, mel_input)
             # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
             attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
             loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss)