diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 09985c82..b30f566b 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -8,7 +8,8 @@ from torch.nn import functional from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.ssim import ssim -from TTS.utils.audio import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.vocoder.layers.losses import MultiScaleSTFTLoss # pylint: disable=abstract-method @@ -786,16 +787,18 @@ class ForwardTTSLoss(nn.Module): return return_dict -class ForwardTTSE2ELoss(nn.Module): +class ForwardTTSE2eLoss(nn.Module): def __init__(self, config): super().__init__() self.encoder_loss = ForwardTTSLoss(config) + self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params) # for generator losses self.mel_loss_alpha = ( config.mel_loss_alpha ) # mel_loss over the encoder model output as opposed to the vocoder output self.feat_loss_alpha = config.feat_loss_alpha self.gen_loss_alpha = config.gen_loss_alpha + self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha @staticmethod def feature_loss(feats_real, feats_generated): @@ -829,6 +832,8 @@ class ForwardTTSE2ELoss(nn.Module): pitch_output, pitch_target, input_lens, + waveform, + waveform_hat, aligner_logprob=None, aligner_hard=None, aligner_soft=None, @@ -857,10 +862,16 @@ class ForwardTTSE2ELoss(nn.Module): # vocoder generator losses loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha - loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) + loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.mel_loss_alpha + loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform) + loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha + loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha loss_dict["vocoder_loss_mel"] = loss_mel loss_dict["vocoder_loss_feat"] = loss_feat loss_dict["vocoder_loss_gen"] = loss_gen - loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen + loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg + loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc + + loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen + loss_stft_sc + loss_stft_mg return loss_dict