diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 10ee3905..efd0c2cb 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -240,10 +240,12 @@ class TacotronLoss(torch.nn.Module): super(TacotronLoss, self).__init__() self.stopnet_pos_weight = stopnet_pos_weight self.ga_alpha = c.ga_alpha - self.diff_spec_alpha = c.diff_spec_alpha + self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha + self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha self.decoder_alpha = c.decoder_loss_alpha self.postnet_alpha = c.postnet_loss_alpha - self.ssim_alpha = c.ssim_alpha + self.decoder_ssim_alpha = c.decoder_ssim_alpha + self.postnet_ssim_alpha = c.postnet_ssim_alpha self.config = c # postnet and decoder loss