mirror of https://github.com/coqui-ai/TTS.git
Update ForwardTTSE2eLoss
This commit is contained in:
parent
dbe5eb992e
commit
4171f4e9c6
|
@ -8,7 +8,8 @@ from torch.nn import functional
|
||||||
|
|
||||||
from TTS.tts.utils.helpers import sequence_mask
|
from TTS.tts.utils.helpers import sequence_mask
|
||||||
from TTS.tts.utils.ssim import ssim
|
from TTS.tts.utils.ssim import ssim
|
||||||
from TTS.utils.audio import TorchSTFT
|
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||||
|
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=abstract-method
|
# pylint: disable=abstract-method
|
||||||
|
@ -786,16 +787,18 @@ class ForwardTTSLoss(nn.Module):
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
||||||
|
|
||||||
class ForwardTTSE2ELoss(nn.Module):
|
class ForwardTTSE2eLoss(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder_loss = ForwardTTSLoss(config)
|
self.encoder_loss = ForwardTTSLoss(config)
|
||||||
|
self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params)
|
||||||
# for generator losses
|
# for generator losses
|
||||||
self.mel_loss_alpha = (
|
self.mel_loss_alpha = (
|
||||||
config.mel_loss_alpha
|
config.mel_loss_alpha
|
||||||
) # mel_loss over the encoder model output as opposed to the vocoder output
|
) # mel_loss over the encoder model output as opposed to the vocoder output
|
||||||
self.feat_loss_alpha = config.feat_loss_alpha
|
self.feat_loss_alpha = config.feat_loss_alpha
|
||||||
self.gen_loss_alpha = config.gen_loss_alpha
|
self.gen_loss_alpha = config.gen_loss_alpha
|
||||||
|
self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def feature_loss(feats_real, feats_generated):
|
def feature_loss(feats_real, feats_generated):
|
||||||
|
@ -829,6 +832,8 @@ class ForwardTTSE2ELoss(nn.Module):
|
||||||
pitch_output,
|
pitch_output,
|
||||||
pitch_target,
|
pitch_target,
|
||||||
input_lens,
|
input_lens,
|
||||||
|
waveform,
|
||||||
|
waveform_hat,
|
||||||
aligner_logprob=None,
|
aligner_logprob=None,
|
||||||
aligner_hard=None,
|
aligner_hard=None,
|
||||||
aligner_soft=None,
|
aligner_soft=None,
|
||||||
|
@ -857,10 +862,16 @@ class ForwardTTSE2ELoss(nn.Module):
|
||||||
# vocoder generator losses
|
# vocoder generator losses
|
||||||
loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
|
loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
|
||||||
loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
|
loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
|
||||||
loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat)
|
loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.mel_loss_alpha
|
||||||
|
loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform)
|
||||||
|
loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha
|
||||||
|
loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha
|
||||||
|
|
||||||
loss_dict["vocoder_loss_mel"] = loss_mel
|
loss_dict["vocoder_loss_mel"] = loss_mel
|
||||||
loss_dict["vocoder_loss_feat"] = loss_feat
|
loss_dict["vocoder_loss_feat"] = loss_feat
|
||||||
loss_dict["vocoder_loss_gen"] = loss_gen
|
loss_dict["vocoder_loss_gen"] = loss_gen
|
||||||
loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen
|
loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg
|
||||||
|
loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc
|
||||||
|
|
||||||
|
loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen + loss_stft_sc + loss_stft_mg
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
Loading…
Reference in New Issue