mirror of https://github.com/coqui-ai/TTS.git
fix no loss masking loss computation
parent e8294cb9db
commit 8de7c13708
@@ -2,6 +2,7 @@ import math
 import numpy as np
 import torch
 from torch import nn
+from inspect import signature
 from torch.nn import functional
 from TTS.tts.utils.generic_utils import sequence_mask
 
@@ -142,7 +143,11 @@ class DifferentailSpectralLoss(nn.Module):
     def forward(self, x, target, length):
         x_diff = x[:, 1:] - x[:, :-1]
         target_diff = target[:, 1:] - target[:, :-1]
-        return self.loss_func(x_diff, target_diff, length-1)
+        if len(signature(self.loss_func).parameters) > 2:
+            return self.loss_func(x_diff, target_diff, length-1)
+        else:
+            # if loss masking is not enabled
+            return self.loss_func(x_diff, target_diff)
 
 
 class GuidedAttentionLoss(torch.nn.Module):
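The new branch dispatches on the arity of the wrapped loss function: if `self.loss_func` accepts a third argument it is treated as a masked loss and receives the (shifted) lengths, otherwise it is called with prediction and target only. Below is a minimal standalone sketch of that signature check; `plain_l1`, `masked_l1`, and `call_loss` are hypothetical stand-ins, not names from the repository.

from inspect import signature

import torch
from torch.nn import functional


def plain_l1(x, target):
    # hypothetical unmasked loss: every frame contributes equally
    return functional.l1_loss(x, target)


def masked_l1(x, target, length):
    # hypothetical masked loss: ignore padded frames when averaging
    mask = (torch.arange(x.shape[1])[None, :] < length[:, None]).unsqueeze(-1)
    mask = mask.expand_as(x).float()
    return (torch.abs(x - target) * mask).sum() / mask.sum()


def call_loss(loss_func, x, target, length):
    # pass the lengths only when the loss function accepts a third argument,
    # mirroring the `len(signature(self.loss_func).parameters) > 2` check above
    if len(signature(loss_func).parameters) > 2:
        return loss_func(x, target, length)
    return loss_func(x, target)


x = torch.randn(2, 5, 80)
target = torch.randn(2, 5, 80)
length = torch.tensor([5, 3])
print(call_loss(masked_l1, x, target, length))  # lengths are used
print(call_loss(plain_l1, x, target, length))   # lengths are dropped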
@@ -262,8 +267,11 @@ class TacotronLoss(torch.nn.Module):
 
         # double decoder consistency loss (if enabled)
         if self.config.double_decoder_consistency:
-            decoder_b_loss = self.criterion(decoder_b_output, mel_input,
-                                            output_lens)
+            if self.config.loss_masking:
+                decoder_b_loss = self.criterion(decoder_b_output, mel_input,
+                                                output_lens)
+            else:
+                decoder_b_loss = self.criterion(decoder_b_output, mel_input)
             # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
             attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
             loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss)
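The TacotronLoss hunk applies the same rule at the config level: when `loss_masking` is disabled the criterion is an unmasked loss (assumed here to be a plain `nn.L1Loss()`, which takes only prediction and target), so `output_lens` must not be passed to it. A minimal sketch of the guard around the double decoder consistency term, with hypothetical tensor shapes:

import torch
from torch import nn

criterion = nn.L1Loss()              # assumed unmasked criterion
decoder_b_output = torch.randn(2, 5, 80)
mel_input = torch.randn(2, 5, 80)
output_lens = torch.tensor([5, 3])
loss_masking = False                 # stands in for self.config.loss_masking

# only hand the lengths to the criterion when loss masking is enabled;
# with masking on, the criterion would be a masked loss that accepts lengths
if loss_masking:
    decoder_b_loss = criterion(decoder_b_output, mel_input, output_lens)
else:
    decoder_b_loss = criterion(decoder_b_output, mel_input)
print(decoder_b_loss)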