From e63962c22662d76c0765fdb35fd0b30fce8888c8 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Fri, 22 Nov 2024 00:45:33 +0100
Subject: [PATCH] refactor(losses): move shared losses into losses.py

---
 TTS/tts/layers/losses.py         | 87 +++++++++++++++++++-------------
 TTS/tts/models/delightful_tts.py | 44 ++++------------
 TTS/tts/models/neuralhmm_tts.py  | 19 +------
 TTS/tts/models/overflow.py       | 19 +------
 4 files changed, 64 insertions(+), 105 deletions(-)

diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 5ebed81d..db62430c 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -309,6 +309,24 @@ class ForwardSumLoss(nn.Module):
         return total_loss
 
 
+class NLLLoss(nn.Module):
+    """Negative log likelihood loss."""
+
+    def forward(self, log_prob: torch.Tensor) -> dict:  # pylint: disable=no-self-use
+        """Compute the loss.
+
+        Args:
+            log_prob (Tensor): [B, T, D]
+
+        Returns:
+            dict: {"loss": [1]}
+
+        """
+        return_dict = {}
+        return_dict["loss"] = -log_prob.mean()
+        return return_dict
+
+
 ########################
 # MODEL LOSS LAYERS
 ########################
@@ -619,6 +637,28 @@ class AlignTTSLoss(nn.Module):
         return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
 
 
+def feature_loss(feats_real, feats_generated):
+    loss = 0
+    for dr, dg in zip(feats_real, feats_generated):
+        for rl, gl in zip(dr, dg):
+            rl = rl.float().detach()
+            gl = gl.float()
+            loss += torch.mean(torch.abs(rl - gl))
+    return loss * 2
+
+
+def generator_loss(scores_fake):
+    loss = 0
+    gen_losses = []
+    for dg in scores_fake:
+        dg = dg.float()
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
+
+
 class VitsGeneratorLoss(nn.Module):
     def __init__(self, c: Coqpit):
         super().__init__()
@@ -640,28 +680,6 @@ class VitsGeneratorLoss(nn.Module):
             do_amp_to_db=True,
         )
 
-    @staticmethod
-    def feature_loss(feats_real, feats_generated):
-        loss = 0
-        for dr, dg in zip(feats_real, feats_generated):
-            for rl, gl in zip(dr, dg):
-                rl = rl.float().detach()
-                gl = gl.float()
-                loss += torch.mean(torch.abs(rl - gl))
-        return loss * 2
-
-    @staticmethod
-    def generator_loss(scores_fake):
-        loss = 0
-        gen_losses = []
-        for dg in scores_fake:
-            dg = dg.float()
-            l = torch.mean((1 - dg) ** 2)
-            gen_losses.append(l)
-            loss += l
-
-        return loss, gen_losses
-
     @staticmethod
     def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
         """
@@ -722,10 +740,8 @@ class VitsGeneratorLoss(nn.Module):
             self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
             * self.kl_loss_alpha
         )
-        loss_feat = (
-            self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
-        )
-        loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
+        loss_feat = feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
+        loss_gen = generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
         loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
@@ -779,6 +795,15 @@ class VitsDiscriminatorLoss(nn.Module):
         return return_dict
 
 
+def _binary_alignment_loss(alignment_hard, alignment_soft):
+    """Binary loss that forces soft alignments to match the hard alignments.
+
+    Explained in `https://arxiv.org/pdf/2108.10447.pdf`.
+ """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + class ForwardTTSLoss(nn.Module): """Generic configurable ForwardTTS loss.""" @@ -820,14 +845,6 @@ class ForwardTTSLoss(nn.Module): self.dur_loss_alpha = c.dur_loss_alpha self.binary_alignment_loss_alpha = c.binary_align_loss_alpha - @staticmethod - def _binary_alignment_loss(alignment_hard, alignment_soft): - """Binary loss that forces soft alignments to match the hard alignments as - explained in `https://arxiv.org/pdf/2108.10447.pdf`. - """ - log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() - return -log_sum / alignment_hard.sum() - def forward( self, decoder_output, @@ -879,7 +896,7 @@ class ForwardTTSLoss(nn.Module): return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: - binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + binary_alignment_loss = _binary_alignment_loss(alignment_hard, alignment_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss if binary_loss_weight: return_dict["loss_binary_alignment"] = ( diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index 2f34e432..7216e814 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -19,7 +19,13 @@ from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel -from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss +from TTS.tts.layers.losses import ( + ForwardSumLoss, + VitsDiscriminatorLoss, + _binary_alignment_loss, + feature_loss, + generator_loss, +) from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.models.base_tts import BaseTTSE2E from TTS.tts.models.vits import load_audio @@ -1491,36 +1497,6 @@ class DelightfulTTSLoss(nn.Module): self.gen_loss_alpha = config.gen_loss_alpha self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha - @staticmethod - def _binary_alignment_loss(alignment_hard, alignment_soft): - """Binary loss that forces soft alignments to match the hard alignments as - explained in `https://arxiv.org/pdf/2108.10447.pdf`. 
- """ - log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() - return -log_sum / alignment_hard.sum() - - @staticmethod - def feature_loss(feats_real, feats_generated): - loss = 0 - for dr, dg in zip(feats_real, feats_generated): - for rl, gl in zip(dr, dg): - rl = rl.float().detach() - gl = gl.float() - loss += torch.mean(torch.abs(rl - gl)) - return loss * 2 - - @staticmethod - def generator_loss(scores_fake): - loss = 0 - gen_losses = [] - for dg in scores_fake: - dg = dg.float() - l = torch.mean((1 - dg) ** 2) - gen_losses.append(l) - loss += l - - return loss, gen_losses - def forward( self, mel_output, @@ -1618,7 +1594,7 @@ class DelightfulTTSLoss(nn.Module): ) if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: - binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) + binary_alignment_loss = _binary_alignment_loss(aligner_hard, aligner_soft) total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight if binary_loss_weight: loss_dict["loss_binary_alignment"] = ( @@ -1638,8 +1614,8 @@ class DelightfulTTSLoss(nn.Module): # vocoder losses if not skip_disc: - loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha - loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha + loss_feat = feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha + loss_gen = generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha loss_dict["vocoder_loss_feat"] = loss_feat loss_dict["vocoder_loss_gen"] = loss_gen loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py index de5401aa..0b3fadaf 100644 --- a/TTS/tts/models/neuralhmm_tts.py +++ b/TTS/tts/models/neuralhmm_tts.py @@ -8,6 +8,7 @@ from torch import nn from trainer.io import load_fsspec from trainer.logging.tensorboard_logger import TensorboardLogger +from TTS.tts.layers.losses import NLLLoss from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils from TTS.tts.layers.overflow.neural_hmm import NeuralHMM from TTS.tts.layers.overflow.plotting_utils import ( @@ -373,21 +374,3 @@ class NeuralhmmTTS(BaseTTS): ) -> None: logger.test_audios(steps, outputs[1], self.ap.sample_rate) logger.test_figures(steps, outputs[0]) - - -class NLLLoss(nn.Module): - """Negative log likelihood loss.""" - - def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use - """Compute the loss. 
-
-        Args:
-            logits (Tensor): [B, T, D]
-
-        Returns:
-            Tensor: [1]
-
-        """
-        return_dict = {}
-        return_dict["loss"] = -log_prob.mean()
-        return return_dict
diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py
index b72f4877..ac09e406 100644
--- a/TTS/tts/models/overflow.py
+++ b/TTS/tts/models/overflow.py
@@ -8,6 +8,7 @@ from torch import nn
 from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger
 
+from TTS.tts.layers.losses import NLLLoss
 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
 from TTS.tts.layers.overflow.decoder import Decoder
 from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
@@ -389,21 +390,3 @@ class Overflow(BaseTTS):
     ) -> None:
         logger.test_audios(steps, outputs[1], self.ap.sample_rate)
         logger.test_figures(steps, outputs[0])
-
-
-class NLLLoss(nn.Module):
-    """Negative log likelihood loss."""
-
-    def forward(self, log_prob: torch.Tensor) -> dict:  # pylint: disable=no-self-use
-        """Compute the loss.
-
-        Args:
-            logits (Tensor): [B, T, D]
-
-        Returns:
-            Tensor: [1]
-
-        """
-        return_dict = {}
-        return_dict["loss"] = -log_prob.mean()
-        return return_dict
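
Notes (outside the patch): with `feature_loss` and `generator_loss` promoted to module-level functions in `TTS.tts.layers.losses`, they can be smoke-tested in isolation. A minimal sketch, assuming the package and PyTorch are installed; the tensor shapes are illustrative stand-ins for discriminator outputs, not values from a real model:

    import torch

    from TTS.tts.layers.losses import feature_loss, generator_loss

    # Illustrative discriminator outputs: one score tensor and one list of
    # feature maps per sub-discriminator (shapes made up for the demo).
    scores_fake = [torch.rand(4, 1), torch.rand(4, 1)]
    feats_real = [[torch.rand(4, 16, 10)], [torch.rand(4, 16, 10)]]
    feats_fake = [[torch.rand(4, 16, 10)], [torch.rand(4, 16, 10)]]

    loss_feat = feature_loss(feats_real=feats_real, feats_generated=feats_fake)
    loss_gen, per_disc_losses = generator_loss(scores_fake=scores_fake)
    print(loss_feat.item(), loss_gen.item(), len(per_disc_losses))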
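
Similarly, the shared `_binary_alignment_loss` only needs a pair of alignment matrices (hard one-hot, soft probabilities), so a quick check could look like this; the shapes and the random hard alignment are hypothetical:

    import torch

    from TTS.tts.layers.losses import _binary_alignment_loss

    alignment_soft = torch.softmax(torch.rand(1, 40, 12), dim=2)  # [B, T_spec, T_text]
    alignment_hard = torch.zeros_like(alignment_soft)             # one-hot over text
    alignment_hard[0, torch.arange(40), torch.randint(0, 12, (40,))] = 1

    loss = _binary_alignment_loss(alignment_hard, alignment_soft)
    print(loss.item())  # non-negative; 0 only if soft mass sits exactly on the hard path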
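
And the shared `NLLLoss` simply negates the mean of whatever log-probabilities the neural HMM produced, so any tensor works for a sanity check (the shape below is arbitrary):

    import torch

    from TTS.tts.layers.losses import NLLLoss

    criterion = NLLLoss()
    log_prob = torch.log(torch.rand(8, 100) + 1e-6)  # stand-in for HMM log-likelihoods
    print(criterion(log_prob)["loss"])  # equals -log_prob.mean()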