mirror of https://github.com/coqui-ai/TTS.git
refactor(losses): move shared losses into losses.py
This commit is contained in:
parent
6f25c2b904
commit
e63962c226
|
@ -309,6 +309,24 @@ class ForwardSumLoss(nn.Module):
|
|||
return total_loss
|
||||
|
||||
|
||||
class NLLLoss(nn.Module):
|
||||
"""Negative log likelihood loss."""
|
||||
|
||||
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
|
||||
"""Compute the loss.
|
||||
|
||||
Args:
|
||||
logits (Tensor): [B, T, D]
|
||||
|
||||
Returns:
|
||||
Tensor: [1]
|
||||
|
||||
"""
|
||||
return_dict = {}
|
||||
return_dict["loss"] = -log_prob.mean()
|
||||
return return_dict
|
||||
|
||||
|
||||
########################
|
||||
# MODEL LOSS LAYERS
|
||||
########################
|
||||
|
@ -619,6 +637,28 @@ class AlignTTSLoss(nn.Module):
|
|||
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
|
||||
|
||||
|
||||
def feature_loss(feats_real, feats_generated):
|
||||
loss = 0
|
||||
for dr, dg in zip(feats_real, feats_generated):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
return loss * 2
|
||||
|
||||
|
||||
def generator_loss(scores_fake):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in scores_fake:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
class VitsGeneratorLoss(nn.Module):
|
||||
def __init__(self, c: Coqpit):
|
||||
super().__init__()
|
||||
|
@ -640,28 +680,6 @@ class VitsGeneratorLoss(nn.Module):
|
|||
do_amp_to_db=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def feature_loss(feats_real, feats_generated):
|
||||
loss = 0
|
||||
for dr, dg in zip(feats_real, feats_generated):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
return loss * 2
|
||||
|
||||
@staticmethod
|
||||
def generator_loss(scores_fake):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in scores_fake:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
@staticmethod
|
||||
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
||||
"""
|
||||
|
@ -722,10 +740,8 @@ class VitsGeneratorLoss(nn.Module):
|
|||
self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
|
||||
* self.kl_loss_alpha
|
||||
)
|
||||
loss_feat = (
|
||||
self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
|
||||
)
|
||||
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_feat = feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
|
||||
loss_gen = generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
|
||||
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
||||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||
|
@ -779,6 +795,15 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
return return_dict
|
||||
|
||||
|
||||
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||
"""Binary loss that forces soft alignments to match the hard alignments.
|
||||
|
||||
Explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||
"""
|
||||
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||
return -log_sum / alignment_hard.sum()
|
||||
|
||||
|
||||
class ForwardTTSLoss(nn.Module):
|
||||
"""Generic configurable ForwardTTS loss."""
|
||||
|
||||
|
@ -820,14 +845,6 @@ class ForwardTTSLoss(nn.Module):
|
|||
self.dur_loss_alpha = c.dur_loss_alpha
|
||||
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
|
||||
|
||||
@staticmethod
|
||||
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||
"""
|
||||
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||
return -log_sum / alignment_hard.sum()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_output,
|
||||
|
@ -879,7 +896,7 @@ class ForwardTTSLoss(nn.Module):
|
|||
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
|
||||
|
||||
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
|
||||
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
|
||||
binary_alignment_loss = _binary_alignment_loss(alignment_hard, alignment_soft)
|
||||
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||
if binary_loss_weight:
|
||||
return_dict["loss_binary_alignment"] = (
|
||||
|
|
|
@ -19,7 +19,13 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
|
|||
|
||||
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
|
||||
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
|
||||
from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss
|
||||
from TTS.tts.layers.losses import (
|
||||
ForwardSumLoss,
|
||||
VitsDiscriminatorLoss,
|
||||
_binary_alignment_loss,
|
||||
feature_loss,
|
||||
generator_loss,
|
||||
)
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
from TTS.tts.models.base_tts import BaseTTSE2E
|
||||
from TTS.tts.models.vits import load_audio
|
||||
|
@ -1491,36 +1497,6 @@ class DelightfulTTSLoss(nn.Module):
|
|||
self.gen_loss_alpha = config.gen_loss_alpha
|
||||
self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha
|
||||
|
||||
@staticmethod
|
||||
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||
"""
|
||||
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||
return -log_sum / alignment_hard.sum()
|
||||
|
||||
@staticmethod
|
||||
def feature_loss(feats_real, feats_generated):
|
||||
loss = 0
|
||||
for dr, dg in zip(feats_real, feats_generated):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
return loss * 2
|
||||
|
||||
@staticmethod
|
||||
def generator_loss(scores_fake):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in scores_fake:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
def forward(
|
||||
self,
|
||||
mel_output,
|
||||
|
@ -1618,7 +1594,7 @@ class DelightfulTTSLoss(nn.Module):
|
|||
)
|
||||
|
||||
if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None:
|
||||
binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft)
|
||||
binary_alignment_loss = _binary_alignment_loss(aligner_hard, aligner_soft)
|
||||
total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
|
||||
if binary_loss_weight:
|
||||
loss_dict["loss_binary_alignment"] = (
|
||||
|
@ -1638,8 +1614,8 @@ class DelightfulTTSLoss(nn.Module):
|
|||
|
||||
# vocoder losses
|
||||
if not skip_disc:
|
||||
loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
|
||||
loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
|
||||
loss_feat = feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
|
||||
loss_gen = generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
|
||||
loss_dict["vocoder_loss_feat"] = loss_feat
|
||||
loss_dict["vocoder_loss_gen"] = loss_gen
|
||||
loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen
|
||||
|
|
|
@ -8,6 +8,7 @@ from torch import nn
|
|||
from trainer.io import load_fsspec
|
||||
from trainer.logging.tensorboard_logger import TensorboardLogger
|
||||
|
||||
from TTS.tts.layers.losses import NLLLoss
|
||||
from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
|
||||
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
|
||||
from TTS.tts.layers.overflow.plotting_utils import (
|
||||
|
@ -373,21 +374,3 @@ class NeuralhmmTTS(BaseTTS):
|
|||
) -> None:
|
||||
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
|
||||
logger.test_figures(steps, outputs[0])
|
||||
|
||||
|
||||
class NLLLoss(nn.Module):
|
||||
"""Negative log likelihood loss."""
|
||||
|
||||
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
|
||||
"""Compute the loss.
|
||||
|
||||
Args:
|
||||
logits (Tensor): [B, T, D]
|
||||
|
||||
Returns:
|
||||
Tensor: [1]
|
||||
|
||||
"""
|
||||
return_dict = {}
|
||||
return_dict["loss"] = -log_prob.mean()
|
||||
return return_dict
|
||||
|
|
|
@ -8,6 +8,7 @@ from torch import nn
|
|||
from trainer.io import load_fsspec
|
||||
from trainer.logging.tensorboard_logger import TensorboardLogger
|
||||
|
||||
from TTS.tts.layers.losses import NLLLoss
|
||||
from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
|
||||
from TTS.tts.layers.overflow.decoder import Decoder
|
||||
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
|
||||
|
@ -389,21 +390,3 @@ class Overflow(BaseTTS):
|
|||
) -> None:
|
||||
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
|
||||
logger.test_figures(steps, outputs[0])
|
||||
|
||||
|
||||
class NLLLoss(nn.Module):
|
||||
"""Negative log likelihood loss."""
|
||||
|
||||
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
|
||||
"""Compute the loss.
|
||||
|
||||
Args:
|
||||
logits (Tensor): [B, T, D]
|
||||
|
||||
Returns:
|
||||
Tensor: [1]
|
||||
|
||||
"""
|
||||
return_dict = {}
|
||||
return_dict["loss"] = -log_prob.mean()
|
||||
return return_dict
|
||||
|
|
Loading…
Reference in New Issue