refactor(losses): move shared losses into losses.py

This commit is contained in:
Enno Hermann 2024-11-22 00:45:33 +01:00
parent 6f25c2b904
commit e63962c226
4 changed files with 64 additions and 105 deletions

View File

@ -309,6 +309,24 @@ class ForwardSumLoss(nn.Module):
return total_loss
class NLLLoss(nn.Module):
"""Negative log likelihood loss."""
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
"""Compute the loss.
Args:
logits (Tensor): [B, T, D]
Returns:
Tensor: [1]
"""
return_dict = {}
return_dict["loss"] = -log_prob.mean()
return return_dict
########################
# MODEL LOSS LAYERS
########################
@ -619,6 +637,28 @@ class AlignTTSLoss(nn.Module):
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
def feature_loss(feats_real, feats_generated):
loss = 0
for dr, dg in zip(feats_real, feats_generated):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def generator_loss(scores_fake):
loss = 0
gen_losses = []
for dg in scores_fake:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
class VitsGeneratorLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
@ -640,28 +680,6 @@ class VitsGeneratorLoss(nn.Module):
do_amp_to_db=True,
)
@staticmethod
def feature_loss(feats_real, feats_generated):
loss = 0
for dr, dg in zip(feats_real, feats_generated):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
@staticmethod
def generator_loss(scores_fake):
loss = 0
gen_losses = []
for dg in scores_fake:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
@staticmethod
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
@ -722,10 +740,8 @@ class VitsGeneratorLoss(nn.Module):
self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
* self.kl_loss_alpha
)
loss_feat = (
self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
)
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
loss_feat = feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
loss_gen = generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
@ -779,6 +795,15 @@ class VitsDiscriminatorLoss(nn.Module):
return return_dict
def _binary_alignment_loss(alignment_hard, alignment_soft):
"""Binary loss that forces soft alignments to match the hard alignments.
Explained in `https://arxiv.org/pdf/2108.10447.pdf`.
"""
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
return -log_sum / alignment_hard.sum()
class ForwardTTSLoss(nn.Module):
"""Generic configurable ForwardTTS loss."""
@ -820,14 +845,6 @@ class ForwardTTSLoss(nn.Module):
self.dur_loss_alpha = c.dur_loss_alpha
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
@staticmethod
def _binary_alignment_loss(alignment_hard, alignment_soft):
"""Binary loss that forces soft alignments to match the hard alignments as
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
"""
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
return -log_sum / alignment_hard.sum()
def forward(
self,
decoder_output,
@ -879,7 +896,7 @@ class ForwardTTSLoss(nn.Module):
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
binary_alignment_loss = _binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
if binary_loss_weight:
return_dict["loss_binary_alignment"] = (

View File

@ -19,7 +19,13 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss
from TTS.tts.layers.losses import (
ForwardSumLoss,
VitsDiscriminatorLoss,
_binary_alignment_loss,
feature_loss,
generator_loss,
)
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.models.base_tts import BaseTTSE2E
from TTS.tts.models.vits import load_audio
@ -1491,36 +1497,6 @@ class DelightfulTTSLoss(nn.Module):
self.gen_loss_alpha = config.gen_loss_alpha
self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha
@staticmethod
def _binary_alignment_loss(alignment_hard, alignment_soft):
"""Binary loss that forces soft alignments to match the hard alignments as
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
"""
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
return -log_sum / alignment_hard.sum()
@staticmethod
def feature_loss(feats_real, feats_generated):
loss = 0
for dr, dg in zip(feats_real, feats_generated):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
@staticmethod
def generator_loss(scores_fake):
loss = 0
gen_losses = []
for dg in scores_fake:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
def forward(
self,
mel_output,
@ -1618,7 +1594,7 @@ class DelightfulTTSLoss(nn.Module):
)
if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None:
binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft)
binary_alignment_loss = _binary_alignment_loss(aligner_hard, aligner_soft)
total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
if binary_loss_weight:
loss_dict["loss_binary_alignment"] = (
@ -1638,8 +1614,8 @@ class DelightfulTTSLoss(nn.Module):
# vocoder losses
if not skip_disc:
loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
loss_feat = feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
loss_gen = generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
loss_dict["vocoder_loss_feat"] = loss_feat
loss_dict["vocoder_loss_gen"] = loss_gen
loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen

View File

@ -8,6 +8,7 @@ from torch import nn
from trainer.io import load_fsspec
from trainer.logging.tensorboard_logger import TensorboardLogger
from TTS.tts.layers.losses import NLLLoss
from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
from TTS.tts.layers.overflow.plotting_utils import (
@ -373,21 +374,3 @@ class NeuralhmmTTS(BaseTTS):
) -> None:
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
logger.test_figures(steps, outputs[0])
class NLLLoss(nn.Module):
"""Negative log likelihood loss."""
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
"""Compute the loss.
Args:
logits (Tensor): [B, T, D]
Returns:
Tensor: [1]
"""
return_dict = {}
return_dict["loss"] = -log_prob.mean()
return return_dict

View File

@ -8,6 +8,7 @@ from torch import nn
from trainer.io import load_fsspec
from trainer.logging.tensorboard_logger import TensorboardLogger
from TTS.tts.layers.losses import NLLLoss
from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
from TTS.tts.layers.overflow.decoder import Decoder
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
@ -389,21 +390,3 @@ class Overflow(BaseTTS):
) -> None:
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
logger.test_figures(steps, outputs[0])
class NLLLoss(nn.Module):
"""Negative log likelihood loss."""
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
"""Compute the loss.
Args:
logits (Tensor): [B, T, D]
Returns:
Tensor: [1]
"""
return_dict = {}
return_dict["loss"] = -log_prob.mean()
return return_dict