mirror of https://github.com/coqui-ai/TTS.git
Fix loss computation when loss masking is disabled
This commit is contained in:
parent
e8294cb9db
commit
8de7c13708
|
@ -2,6 +2,7 @@ import math
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from inspect import signature
|
||||
from torch.nn import functional
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
|
||||
|
@ -142,7 +143,11 @@ class DifferentailSpectralLoss(nn.Module):
|
|||
def forward(self, x, target, length):
    """Apply ``self.loss_func`` to the first-order differences of ``x`` and ``target``.

    Differences are taken between consecutive entries along axis 1
    (presumably the time axis of a spectrogram batch -- confirm against
    the caller), so each differenced tensor has one entry fewer along
    that axis than its input.

    Args:
        x: Indexable as ``x[:, i]``; model output.
        target: Same layout as ``x``; reference values.
        length: Per-item valid lengths. Only forwarded (as ``length - 1``)
            when the wrapped loss accepts a third argument.

    Returns:
        Whatever ``self.loss_func`` returns.
    """
    x_diff = x[:, 1:] - x[:, :-1]
    target_diff = target[:, 1:] - target[:, :-1]

    # Inspect ``forward`` when the loss exposes one: the signature of a
    # callable object is taken from its type's ``__call__``, which for
    # wrapper-style callables is just ``(*args, **kwargs)`` and would hide
    # the real number of arguments the loss accepts.
    loss_callable = getattr(self.loss_func, "forward", self.loss_func)
    if len(signature(loss_callable).parameters) > 2:
        # Differencing drops one entry along axis 1, so the valid lengths
        # shrink by one as well.
        return self.loss_func(x_diff, target_diff, length - 1)

    # if loss masking is not enabled
    return self.loss_func(x_diff, target_diff)
|
||||
|
||||
|
||||
class GuidedAttentionLoss(torch.nn.Module):
|
||||
|
@ -262,8 +267,11 @@ class TacotronLoss(torch.nn.Module):
|
|||
|
||||
# double decoder consistency loss (if enabled)
|
||||
if self.config.double_decoder_consistency:
|
||||
decoder_b_loss = self.criterion(decoder_b_output, mel_input,
|
||||
output_lens)
|
||||
if self.config.loss_masking:
|
||||
decoder_b_loss = self.criterion(decoder_b_output, mel_input,
|
||||
output_lens)
|
||||
else:
|
||||
decoder_b_loss = self.criterion(decoder_b_output, mel_input)
|
||||
# decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
|
||||
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
|
||||
loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss)
|
||||
|
|
Loading…
Reference in New Issue