Implement ForwardTTSE2E Loss

Eren Gölge 2022-04-04 09:42:50 +02:00
parent 95b52a65af
commit 29216ff907
1 changed file with 86 additions and 6 deletions

@@ -739,9 +739,9 @@ class ForwardTTSLoss(nn.Module):
         pitch_output,
         pitch_target,
         input_lens,
-        alignment_logprob=None,
-        alignment_hard=None,
-        alignment_soft=None,
+        aligner_logprob=None,
+        aligner_hard=None,
+        aligner_soft=None,
         binary_loss_weight=None,
     ):
         loss = 0
@@ -768,12 +768,12 @@ class ForwardTTSLoss(nn.Module):
             return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss

         if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
-            aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
+            aligner_loss = self.aligner_loss(aligner_logprob, input_lens, decoder_output_lens)
             loss = loss + self.aligner_loss_alpha * aligner_loss
             return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss

-        if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
-            binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
+        if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None:
+            binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft)
             loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
             if binary_loss_weight:
                 return_dict["loss_binary_alignment"] = (
@@ -784,3 +784,83 @@ class ForwardTTSLoss(nn.Module):
         return_dict["loss"] = loss
         return return_dict
+
+
+class ForwardTTSE2ELoss(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.encoder_loss = ForwardTTSLoss(config)
+        # for generator losses
+        self.mel_loss_alpha = (
+            config.mel_loss_alpha
+        )  # mel_loss over the encoder model output as opposed to the vocoder output
+        self.feat_loss_alpha = config.feat_loss_alpha
+        self.gen_loss_alpha = config.gen_loss_alpha
+
+    @staticmethod
+    def feature_loss(feats_real, feats_generated):
+        loss = 0
+        for dr, dg in zip(feats_real, feats_generated):
+            for rl, gl in zip(dr, dg):
+                rl = rl.float().detach()
+                gl = gl.float()
+                loss += torch.mean(torch.abs(rl - gl))
+        return loss * 2
+
+    @staticmethod
+    def generator_loss(scores_fake):
+        loss = 0
+        gen_losses = []
+        for dg in scores_fake:
+            dg = dg.float()
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l)
+            loss += l
+        return loss, gen_losses
+
+    def forward(
+        self,
+        decoder_output,
+        decoder_target,
+        decoder_output_lens,
+        dur_output,
+        dur_target,
+        pitch_output,
+        pitch_target,
+        input_lens,
+        aligner_logprob=None,
+        aligner_hard=None,
+        aligner_soft=None,
+        binary_loss_weight=None,
+        feats_fake=None,
+        feats_real=None,
+        scores_fake=None,
+        spec_slice=None,
+        spec_slice_hat=None,
+    ):
+        loss_dict = self.encoder_loss(
+            decoder_output=decoder_output,
+            decoder_target=decoder_target,
+            decoder_output_lens=decoder_output_lens,
+            dur_output=dur_output,
+            dur_target=dur_target,
+            pitch_output=pitch_output,
+            pitch_target=pitch_target,
+            input_lens=input_lens,
+            aligner_logprob=aligner_logprob,
+            aligner_hard=aligner_hard,
+            aligner_soft=aligner_soft,
+            binary_loss_weight=binary_loss_weight,
+        )
+        # vocoder generator losses
+        loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
+        loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
+        loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat)
+        loss_dict["vocoder_loss_mel"] = loss_mel
+        loss_dict["vocoder_loss_feat"] = loss_feat
+        loss_dict["vocoder_loss_gen"] = loss_gen
+        loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen
+        return loss_dict
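
Below is a minimal, hypothetical usage sketch (not part of the commit) that exercises the two new static helpers and the sliced-spectrogram L1 term with dummy tensors. The import path TTS.tts.layers.losses and all tensor shapes are assumptions made for illustration; in training, the feature maps and scores would come from the vocoder discriminator.

# Sketch only: dummy inputs standing in for discriminator outputs.
import torch
from TTS.tts.layers.losses import ForwardTTSE2ELoss  # assumed module path

# Two discriminators, three feature maps each (shapes are illustrative).
feats_real = [[torch.randn(1, 16, 100) for _ in range(3)] for _ in range(2)]
feats_fake = [[torch.randn(1, 16, 100) for _ in range(3)] for _ in range(2)]
scores_fake = [torch.rand(1, 1, 100) for _ in range(2)]

# The helpers are @staticmethod, so no config is needed to try them out.
loss_feat = ForwardTTSE2ELoss.feature_loss(feats_real=feats_real, feats_generated=feats_fake)
loss_gen, per_disc_losses = ForwardTTSE2ELoss.generator_loss(scores_fake=scores_fake)

# L1 mel term over matching spectrogram slices, as in forward().
spec_slice = torch.randn(1, 80, 32)
spec_slice_hat = torch.randn(1, 80, 32)
loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat)

print(loss_feat.item(), loss_gen.item(), loss_mel.item())

In forward() these terms are scaled by feat_loss_alpha and gen_loss_alpha from the config and added to the encoder-side loss returned by ForwardTTSLoss, so the end-to-end model is trained with a single combined objective.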