mirror of https://github.com/coqui-ai/TTS.git
Implement ForwardTTSE2E Loss
This commit is contained in:
parent
95b52a65af
commit
29216ff907
|
@ -739,9 +739,9 @@ class ForwardTTSLoss(nn.Module):
|
||||||
pitch_output,
|
pitch_output,
|
||||||
pitch_target,
|
pitch_target,
|
||||||
input_lens,
|
input_lens,
|
||||||
alignment_logprob=None,
|
aligner_logprob=None,
|
||||||
alignment_hard=None,
|
aligner_hard=None,
|
||||||
alignment_soft=None,
|
aligner_soft=None,
|
||||||
binary_loss_weight=None,
|
binary_loss_weight=None,
|
||||||
):
|
):
|
||||||
loss = 0
|
loss = 0
|
||||||
|
@ -768,12 +768,12 @@ class ForwardTTSLoss(nn.Module):
|
||||||
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
|
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
|
||||||
|
|
||||||
if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
|
if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
|
||||||
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
|
aligner_loss = self.aligner_loss(aligner_logprob, input_lens, decoder_output_lens)
|
||||||
loss = loss + self.aligner_loss_alpha * aligner_loss
|
loss = loss + self.aligner_loss_alpha * aligner_loss
|
||||||
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
|
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
|
||||||
|
|
||||||
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
|
if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None:
|
||||||
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
|
binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft)
|
||||||
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
|
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||||
if binary_loss_weight:
|
if binary_loss_weight:
|
||||||
return_dict["loss_binary_alignment"] = (
|
return_dict["loss_binary_alignment"] = (
|
||||||
|
@ -784,3 +784,83 @@ class ForwardTTSLoss(nn.Module):
|
||||||
|
|
||||||
return_dict["loss"] = loss
|
return_dict["loss"] = loss
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardTTSE2ELoss(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder_loss = ForwardTTSLoss(config)
|
||||||
|
# for generator losses
|
||||||
|
self.mel_loss_alpha = (
|
||||||
|
config.mel_loss_alpha
|
||||||
|
) # mel_loss over the encoder model output as opposed to the vocoder output
|
||||||
|
self.feat_loss_alpha = config.feat_loss_alpha
|
||||||
|
self.gen_loss_alpha = config.gen_loss_alpha
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def feature_loss(feats_real, feats_generated):
|
||||||
|
loss = 0
|
||||||
|
for dr, dg in zip(feats_real, feats_generated):
|
||||||
|
for rl, gl in zip(dr, dg):
|
||||||
|
rl = rl.float().detach()
|
||||||
|
gl = gl.float()
|
||||||
|
loss += torch.mean(torch.abs(rl - gl))
|
||||||
|
return loss * 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generator_loss(scores_fake):
|
||||||
|
loss = 0
|
||||||
|
gen_losses = []
|
||||||
|
for dg in scores_fake:
|
||||||
|
dg = dg.float()
|
||||||
|
l = torch.mean((1 - dg) ** 2)
|
||||||
|
gen_losses.append(l)
|
||||||
|
loss += l
|
||||||
|
|
||||||
|
return loss, gen_losses
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
decoder_output,
|
||||||
|
decoder_target,
|
||||||
|
decoder_output_lens,
|
||||||
|
dur_output,
|
||||||
|
dur_target,
|
||||||
|
pitch_output,
|
||||||
|
pitch_target,
|
||||||
|
input_lens,
|
||||||
|
aligner_logprob=None,
|
||||||
|
aligner_hard=None,
|
||||||
|
aligner_soft=None,
|
||||||
|
binary_loss_weight=None,
|
||||||
|
feats_fake=None,
|
||||||
|
feats_real=None,
|
||||||
|
scores_fake=None,
|
||||||
|
spec_slice=None,
|
||||||
|
spec_slice_hat=None,
|
||||||
|
):
|
||||||
|
loss_dict = self.encoder_loss(
|
||||||
|
decoder_output=decoder_output,
|
||||||
|
decoder_target=decoder_target,
|
||||||
|
decoder_output_lens=decoder_output_lens,
|
||||||
|
dur_output=dur_output,
|
||||||
|
dur_target=dur_target,
|
||||||
|
pitch_output=pitch_output,
|
||||||
|
pitch_target=pitch_target,
|
||||||
|
input_lens=input_lens,
|
||||||
|
aligner_logprob=aligner_logprob,
|
||||||
|
aligner_hard=aligner_hard,
|
||||||
|
aligner_soft=aligner_soft,
|
||||||
|
binary_loss_weight=binary_loss_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
# vocoder generator losses
|
||||||
|
loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
|
||||||
|
loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
|
||||||
|
loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat)
|
||||||
|
|
||||||
|
loss_dict["vocoder_loss_mel"] = loss_mel
|
||||||
|
loss_dict["vocoder_loss_feat"] = loss_feat
|
||||||
|
loss_dict["vocoder_loss_gen"] = loss_gen
|
||||||
|
loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen
|
||||||
|
return loss_dict
|
||||||
|
|
Loading…
Reference in New Issue