Update ForwardTTSE2eLoss

This commit is contained in:
Eren Gölge 2022-04-19 09:22:50 +02:00 committed by Eren G??lge
parent dbe5eb992e
commit 4171f4e9c6
1 changed files with 15 additions and 4 deletions

View File

@ -8,7 +8,8 @@ from torch.nn import functional
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import ssim
from TTS.utils.audio import TorchSTFT
from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
# pylint: disable=abstract-method
@ -786,16 +787,18 @@ class ForwardTTSLoss(nn.Module):
return return_dict
class ForwardTTSE2ELoss(nn.Module):
class ForwardTTSE2eLoss(nn.Module):
def __init__(self, config):
super().__init__()
self.encoder_loss = ForwardTTSLoss(config)
self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params)
# for generator losses
self.mel_loss_alpha = (
config.mel_loss_alpha
) # mel_loss over the encoder model output as opposed to the vocoder output
self.feat_loss_alpha = config.feat_loss_alpha
self.gen_loss_alpha = config.gen_loss_alpha
self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha
@staticmethod
def feature_loss(feats_real, feats_generated):
@ -829,6 +832,8 @@ class ForwardTTSE2ELoss(nn.Module):
pitch_output,
pitch_target,
input_lens,
waveform,
waveform_hat,
aligner_logprob=None,
aligner_hard=None,
aligner_soft=None,
@ -857,10 +862,16 @@ class ForwardTTSE2ELoss(nn.Module):
# vocoder generator losses
loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat)
loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.mel_loss_alpha
loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform)
loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha
loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha
loss_dict["vocoder_loss_mel"] = loss_mel
loss_dict["vocoder_loss_feat"] = loss_feat
loss_dict["vocoder_loss_gen"] = loss_gen
loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen
loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg
loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc
loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen + loss_stft_sc + loss_stft_mg
return loss_dict