mirror of https://github.com/coqui-ai/TTS.git
Update ForwardTTSE2eLoss
This commit is contained in:
parent
dbe5eb992e
commit
4171f4e9c6
|
@@ -8,7 +8,8 @@ from torch.nn import functional
|
|||
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.ssim import ssim
|
||||
from TTS.utils.audio import TorchSTFT
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
||||
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
|
@@ -786,16 +787,18 @@ class ForwardTTSLoss(nn.Module):
|
|||
return return_dict
|
||||
|
||||
|
||||
class ForwardTTSE2ELoss(nn.Module):
|
||||
class ForwardTTSE2eLoss(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.encoder_loss = ForwardTTSLoss(config)
|
||||
self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params)
|
||||
# for generator losses
|
||||
self.mel_loss_alpha = (
|
||||
config.mel_loss_alpha
|
||||
) # mel_loss over the encoder model output as opposed to the vocoder output
|
||||
self.feat_loss_alpha = config.feat_loss_alpha
|
||||
self.gen_loss_alpha = config.gen_loss_alpha
|
||||
self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha
|
||||
|
||||
@staticmethod
|
||||
def feature_loss(feats_real, feats_generated):
|
||||
|
@@ -829,6 +832,8 @@ class ForwardTTSE2ELoss(nn.Module):
|
|||
pitch_output,
|
||||
pitch_target,
|
||||
input_lens,
|
||||
waveform,
|
||||
waveform_hat,
|
||||
aligner_logprob=None,
|
||||
aligner_hard=None,
|
||||
aligner_soft=None,
|
||||
|
@ -857,10 +862,16 @@ class ForwardTTSE2ELoss(nn.Module):
|
|||
# vocoder generator losses
|
||||
loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
|
||||
loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat)
|
||||
loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.mel_loss_alpha
|
||||
loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform)
|
||||
loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha
|
||||
loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha
|
||||
|
||||
loss_dict["vocoder_loss_mel"] = loss_mel
|
||||
loss_dict["vocoder_loss_feat"] = loss_feat
|
||||
loss_dict["vocoder_loss_gen"] = loss_gen
|
||||
loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen
|
||||
loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg
|
||||
loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc
|
||||
|
||||
loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen + loss_stft_sc + loss_stft_mg
|
||||
return loss_dict
|
||||
|
|
Loading…
Reference in New Issue