From 570d5971be3866be9c19e914464b8156b478df9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 10 Sep 2021 08:29:12 +0000 Subject: [PATCH] Implement `ForwardTTSLoss` --- TTS/tts/layers/losses.py | 121 ++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 60 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 2c752376..72e7d8d5 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -236,10 +236,40 @@ class Huber(nn.Module): y: B x T length: B """ - mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float() + mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float() return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum() +class ForwardSumLoss(nn.Module): + def __init__(self, blank_logprob=-1): + super().__init__() + self.log_softmax = torch.nn.LogSoftmax(dim=3) + self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True) + self.blank_logprob = blank_logprob + + def forward(self, attn_logprob, in_lens, out_lens): + key_lens = in_lens + query_lens = out_lens + attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob) + + total_loss = 0.0 + for bid in range(attn_logprob.shape[0]): + target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0) + curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1] + + curr_logprob = self.log_softmax(curr_logprob[None])[0] + loss = self.ctc_loss( + curr_logprob, + target_seq, + input_lengths=query_lens[bid : bid + 1], + target_lengths=key_lens[bid : bid + 1], + ) + total_loss = total_loss + loss + + total_loss = total_loss / attn_logprob.shape[0] + return total_loss + + ######################## # MODEL LOSS LAYERS ######################## @@ -413,25 +443,6 @@ class GlowTTSLoss(torch.nn.Module): return return_dict -class SpeedySpeechLoss(nn.Module): - def __init__(self, c): - super().__init__() - self.l1 = L1LossMasked(False) - self.ssim = SSIMLoss() - self.huber = Huber() - - self.ssim_alpha = c.ssim_alpha - self.huber_alpha = c.huber_alpha - self.l1_alpha = c.l1_alpha - - def forward(self, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens): - l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) - ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - huber_loss = self.huber(dur_output, dur_target, input_lens) - loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss - return {"loss": loss, "loss_l1": l1_loss, "loss_ssim": ssim_loss, "loss_dur": huber_loss} - - def mse_loss_custom(x, y): """MSE loss using the torch back-end without reduction. 
It uses less VRAM than the raw code""" @@ -660,51 +671,41 @@ class VitsDiscriminatorLoss(nn.Module): return return_dict -class ForwardSumLoss(nn.Module): - def __init__(self, blank_logprob=-1): - super().__init__() - self.log_softmax = torch.nn.LogSoftmax(dim=3) - self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True) - self.blank_logprob = blank_logprob +class ForwardTTSLoss(nn.Module): + """Generic configurable ForwardTTS loss.""" - def forward(self, attn_logprob, in_lens, out_lens): - key_lens = in_lens - query_lens = out_lens - attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob) - - total_loss = 0.0 - for bid in range(attn_logprob.shape[0]): - target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0) - curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1] - - curr_logprob = self.log_softmax(curr_logprob[None])[0] - loss = self.ctc_loss( - curr_logprob, - target_seq, - input_lengths=query_lens[bid : bid + 1], - target_lengths=key_lens[bid : bid + 1], - ) - total_loss = total_loss + loss - - total_loss = total_loss / attn_logprob.shape[0] - return total_loss - - -class FastPitchLoss(nn.Module): def __init__(self, c): super().__init__() - self.spec_loss = MSELossMasked(False) - self.ssim = SSIMLoss() - self.dur_loss = MSELossMasked(False) - self.pitch_loss = MSELossMasked(False) + if c.spec_loss_type == "mse": + self.spec_loss = MSELossMasked(False) + elif c.spec_loss_type == "l1": + self.spec_loss = L1LossMasked(False) + else: + raise ValueError(" [!] Unknown spec_loss_type {}".format(c.spec_loss_type)) + + if c.duration_loss_type == "mse": + self.dur_loss = MSELossMasked(False) + elif c.duration_loss_type == "l1": + self.dur_loss = L1LossMasked(False) + elif c.duration_loss_type == "huber": + self.dur_loss = Huber() + else: + raise ValueError(" [!] 
Unknown duration_loss_type {}".format(c.duration_loss_type))
+
         if c.model_args.use_aligner:
             self.aligner_loss = ForwardSumLoss()
+            self.aligner_loss_alpha = c.aligner_loss_alpha
+
+        if c.model_args.use_pitch:
+            self.pitch_loss = MSELossMasked(False)
+            self.pitch_loss_alpha = c.pitch_loss_alpha
+
+        if c.use_ssim_loss:
+            self.ssim = SSIMLoss()
+            self.ssim_loss_alpha = c.ssim_loss_alpha
 
         self.spec_loss_alpha = c.spec_loss_alpha
-        self.ssim_loss_alpha = c.ssim_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
-        self.pitch_loss_alpha = c.pitch_loss_alpha
-        self.aligner_loss_alpha = c.aligner_loss_alpha
         self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
 
     @staticmethod
@@ -731,7 +732,7 @@ class FastPitchLoss(nn.Module):
     ):
         loss = 0
         return_dict = {}
-        if self.ssim_loss_alpha > 0:
+        if hasattr(self, "ssim") and self.ssim_loss_alpha > 0:
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
             loss = loss + self.ssim_loss_alpha * ssim_loss
             return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
@@ -747,12 +748,12 @@ class FastPitchLoss(nn.Module):
             loss = loss + self.dur_loss_alpha * dur_loss
             return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
 
-        if self.pitch_loss_alpha > 0:
+        if hasattr(self, "pitch_loss") and self.pitch_loss_alpha > 0:
             pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
             loss = loss + self.pitch_loss_alpha * pitch_loss
             return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
 
-        if self.aligner_loss_alpha > 0:
+        if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
             aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
             loss = loss + self.aligner_loss_alpha * aligner_loss
             return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
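
Smoke-test sketch (not part of the patch): a minimal exercise of the two classes this change touches. The tensor shape for `attn_logprob` is inferred from the pad/permute logic in `ForwardSumLoss.forward` ([B, 1, T_de, T_en]), and a `SimpleNamespace` stands in for the real Coqpit config, populated with exactly the fields `ForwardTTSLoss.__init__` reads; all sizes and alpha values are illustrative.

# Quick smoke test for the two classes touched by this patch.
from types import SimpleNamespace

import torch

from TTS.tts.layers.losses import ForwardSumLoss, ForwardTTSLoss

# ForwardSumLoss treats decoder frames as CTC time steps and encoder
# tokens as CTC labels, so attn_logprob is assumed to be [B, 1, T_de, T_en].
B, T_de, T_en = 2, 50, 12
attn_logprob = torch.randn(B, 1, T_de, T_en)
in_lens = torch.tensor([12, 10])   # encoder (key) lengths
out_lens = torch.tensor([50, 44])  # decoder (query) lengths
print(ForwardSumLoss()(attn_logprob, in_lens, out_lens))

# ForwardTTSLoss reads these config fields in __init__; every optional
# term is enabled here so that all branches get constructed.
c = SimpleNamespace(
    spec_loss_type="mse",        # or "l1"
    duration_loss_type="huber",  # or "mse" / "l1"
    use_ssim_loss=True,
    model_args=SimpleNamespace(use_aligner=True, use_pitch=True),
    spec_loss_alpha=1.0,
    ssim_loss_alpha=1.0,
    dur_loss_alpha=1.0,
    pitch_loss_alpha=0.1,
    aligner_loss_alpha=1.0,
    binary_align_loss_alpha=1.0,
)
criterion = ForwardTTSLoss(c)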