mirror of https://github.com/coqui-ai/TTS.git
refactor(losses): move shared losses into losses.py
parent 6f25c2b904
commit e63962c226
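This commit lifts the loss implementations that were duplicated across model files (NLLLoss, feature_loss, generator_loss, and _binary_alignment_loss) to module level in TTS/tts/layers/losses.py and switches VitsGeneratorLoss, ForwardTTSLoss, DelightfulTTSLoss, Overflow, and NeuralhmmTTS over to the shared definitions. After the change, callers import them directly, as the delightful_tts hunk below shows:

    from TTS.tts.layers.losses import (
        ForwardSumLoss,
        VitsDiscriminatorLoss,
        _binary_alignment_loss,
        feature_loss,
        generator_loss,
    )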
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -309,6 +309,24 @@ class ForwardSumLoss(nn.Module):
         return total_loss


+class NLLLoss(nn.Module):
+    """Negative log likelihood loss."""
+
+    def forward(self, log_prob: torch.Tensor) -> dict:  # pylint: disable=no-self-use
+        """Compute the loss.
+
+        Args:
+            logits (Tensor): [B, T, D]
+
+        Returns:
+            Tensor: [1]
+
+        """
+        return_dict = {}
+        return_dict["loss"] = -log_prob.mean()
+        return return_dict
+
+
 ########################
 # MODEL LOSS LAYERS
 ########################
@@ -619,6 +637,28 @@ class AlignTTSLoss(nn.Module):
         return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}


+def feature_loss(feats_real, feats_generated):
+    loss = 0
+    for dr, dg in zip(feats_real, feats_generated):
+        for rl, gl in zip(dr, dg):
+            rl = rl.float().detach()
+            gl = gl.float()
+            loss += torch.mean(torch.abs(rl - gl))
+    return loss * 2
+
+
+def generator_loss(scores_fake):
+    loss = 0
+    gen_losses = []
+    for dg in scores_fake:
+        dg = dg.float()
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
+
+
 class VitsGeneratorLoss(nn.Module):
     def __init__(self, c: Coqpit):
         super().__init__()
@@ -640,28 +680,6 @@ class VitsGeneratorLoss(nn.Module):
             do_amp_to_db=True,
         )

-    @staticmethod
-    def feature_loss(feats_real, feats_generated):
-        loss = 0
-        for dr, dg in zip(feats_real, feats_generated):
-            for rl, gl in zip(dr, dg):
-                rl = rl.float().detach()
-                gl = gl.float()
-                loss += torch.mean(torch.abs(rl - gl))
-        return loss * 2
-
-    @staticmethod
-    def generator_loss(scores_fake):
-        loss = 0
-        gen_losses = []
-        for dg in scores_fake:
-            dg = dg.float()
-            l = torch.mean((1 - dg) ** 2)
-            gen_losses.append(l)
-            loss += l
-
-        return loss, gen_losses
-
     @staticmethod
     def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
         """
@@ -722,10 +740,8 @@ class VitsGeneratorLoss(nn.Module):
             self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
             * self.kl_loss_alpha
         )
-        loss_feat = (
-            self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
-        )
-        loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
+        loss_feat = feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
+        loss_gen = generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
         loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
@@ -779,6 +795,15 @@ class VitsDiscriminatorLoss(nn.Module):
         return return_dict


+def _binary_alignment_loss(alignment_hard, alignment_soft):
+    """Binary loss that forces soft alignments to match the hard alignments.
+
+    Explained in `https://arxiv.org/pdf/2108.10447.pdf`.
+    """
+    log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
+    return -log_sum / alignment_hard.sum()
+
+
 class ForwardTTSLoss(nn.Module):
     """Generic configurable ForwardTTS loss."""

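The new module-level _binary_alignment_loss is the binarization term from the paper linked in its docstring: the mean negative log of the soft alignment probabilities at the positions the hard alignment selects, i.e. -sum(log A_soft[A_hard == 1]) / sum(A_hard), with the probabilities clamped to at least 1e-12 so the log stays finite.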
@@ -820,14 +845,6 @@ class ForwardTTSLoss(nn.Module):
         self.dur_loss_alpha = c.dur_loss_alpha
         self.binary_alignment_loss_alpha = c.binary_align_loss_alpha

-    @staticmethod
-    def _binary_alignment_loss(alignment_hard, alignment_soft):
-        """Binary loss that forces soft alignments to match the hard alignments as
-        explained in `https://arxiv.org/pdf/2108.10447.pdf`.
-        """
-        log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
-        return -log_sum / alignment_hard.sum()
-
     def forward(
         self,
         decoder_output,
@@ -879,7 +896,7 @@ class ForwardTTSLoss(nn.Module):
             return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss

         if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
-            binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
+            binary_alignment_loss = _binary_alignment_loss(alignment_hard, alignment_soft)
             loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
             if binary_loss_weight:
                 return_dict["loss_binary_alignment"] = (
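Since feature_loss and generator_loss are now free functions, they can be exercised in isolation. A minimal sketch with made-up shapes (batch 4, three discriminators, two feature layers each; none of these values come from the diff):

    import torch

    from TTS.tts.layers.losses import feature_loss, generator_loss

    # LSGAN-style generator loss: mean (1 - D(G(z)))^2 per discriminator.
    scores_fake = [torch.rand(4, 1, 100) for _ in range(3)]
    total, per_disc = generator_loss(scores_fake)

    # L1 feature-matching loss over intermediate discriminator features,
    # with real features detached and the summed loss scaled by 2.
    feats_real = [[torch.rand(4, 16, 50) for _ in range(2)] for _ in range(3)]
    feats_fake = [[torch.rand(4, 16, 50) for _ in range(2)] for _ in range(3)]
    fm = feature_loss(feats_real=feats_real, feats_generated=feats_fake)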
--- a/TTS/tts/models/delightful_tts.py
+++ b/TTS/tts/models/delightful_tts.py
@@ -19,7 +19,13 @@ from trainer.trainer_utils import get_optimizer, get_scheduler

 from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
 from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
-from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss
+from TTS.tts.layers.losses import (
+    ForwardSumLoss,
+    VitsDiscriminatorLoss,
+    _binary_alignment_loss,
+    feature_loss,
+    generator_loss,
+)
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.models.base_tts import BaseTTSE2E
 from TTS.tts.models.vits import load_audio
@@ -1491,36 +1497,6 @@ class DelightfulTTSLoss(nn.Module):
         self.gen_loss_alpha = config.gen_loss_alpha
         self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha

-    @staticmethod
-    def _binary_alignment_loss(alignment_hard, alignment_soft):
-        """Binary loss that forces soft alignments to match the hard alignments as
-        explained in `https://arxiv.org/pdf/2108.10447.pdf`.
-        """
-        log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
-        return -log_sum / alignment_hard.sum()
-
-    @staticmethod
-    def feature_loss(feats_real, feats_generated):
-        loss = 0
-        for dr, dg in zip(feats_real, feats_generated):
-            for rl, gl in zip(dr, dg):
-                rl = rl.float().detach()
-                gl = gl.float()
-                loss += torch.mean(torch.abs(rl - gl))
-        return loss * 2
-
-    @staticmethod
-    def generator_loss(scores_fake):
-        loss = 0
-        gen_losses = []
-        for dg in scores_fake:
-            dg = dg.float()
-            l = torch.mean((1 - dg) ** 2)
-            gen_losses.append(l)
-            loss += l
-
-        return loss, gen_losses
-
     def forward(
         self,
         mel_output,
@@ -1618,7 +1594,7 @@ class DelightfulTTSLoss(nn.Module):
         )

         if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None:
-            binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft)
+            binary_alignment_loss = _binary_alignment_loss(aligner_hard, aligner_soft)
             total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
             if binary_loss_weight:
                 loss_dict["loss_binary_alignment"] = (
@@ -1638,8 +1614,8 @@ class DelightfulTTSLoss(nn.Module):

         # vocoder losses
         if not skip_disc:
-            loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
-            loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
+            loss_feat = feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
+            loss_gen = generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
             loss_dict["vocoder_loss_feat"] = loss_feat
             loss_dict["vocoder_loss_gen"] = loss_gen
             loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen
--- a/TTS/tts/models/neuralhmm_tts.py
+++ b/TTS/tts/models/neuralhmm_tts.py
@@ -8,6 +8,7 @@ from torch import nn
 from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger

+from TTS.tts.layers.losses import NLLLoss
 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
 from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
 from TTS.tts.layers.overflow.plotting_utils import (
|
@ -373,21 +374,3 @@ class NeuralhmmTTS(BaseTTS):
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
|
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
|
||||||
logger.test_figures(steps, outputs[0])
|
logger.test_figures(steps, outputs[0])
|
||||||
|
|
||||||
|
|
||||||
class NLLLoss(nn.Module):
|
|
||||||
"""Negative log likelihood loss."""
|
|
||||||
|
|
||||||
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
|
|
||||||
"""Compute the loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
logits (Tensor): [B, T, D]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: [1]
|
|
||||||
|
|
||||||
"""
|
|
||||||
return_dict = {}
|
|
||||||
return_dict["loss"] = -log_prob.mean()
|
|
||||||
return return_dict
|
|
||||||
|
|
|
--- a/TTS/tts/models/overflow.py
+++ b/TTS/tts/models/overflow.py
@@ -8,6 +8,7 @@ from torch import nn
 from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger

+from TTS.tts.layers.losses import NLLLoss
 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
 from TTS.tts.layers.overflow.decoder import Decoder
 from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
@@ -389,21 +390,3 @@ class Overflow(BaseTTS):
     ) -> None:
         logger.test_audios(steps, outputs[1], self.ap.sample_rate)
         logger.test_figures(steps, outputs[0])
-
-
-class NLLLoss(nn.Module):
-    """Negative log likelihood loss."""
-
-    def forward(self, log_prob: torch.Tensor) -> dict:  # pylint: disable=no-self-use
-        """Compute the loss.
-
-        Args:
-            logits (Tensor): [B, T, D]
-
-        Returns:
-            Tensor: [1]
-
-        """
-        return_dict = {}
-        return_dict["loss"] = -log_prob.mean()
-        return return_dict
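Both Overflow and NeuralhmmTTS now share a single NLLLoss; its contract is unchanged: it takes a tensor of log likelihoods and returns a dict with one entry. A quick sketch (the shape is illustrative):

    import torch

    from TTS.tts.layers.losses import NLLLoss

    criterion = NLLLoss()
    log_prob = torch.randn(8)  # per-utterance log likelihoods from the neural HMM
    out = criterion(log_prob)  # {"loss": -log_prob.mean()}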