diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index bf03671c..15b0a2ee 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -270,7 +270,7 @@ class GlowTTSLoss(torch.nn.Module): pz = torch.sum(scales) + 0.5 * torch.sum( torch.exp(-2 * scales) * (z - means)**2) log_mle = self.constant_factor + (pz - torch.sum(log_det)) / ( - torch.sum(y_lengths // 2) * 2 * z.shape[1]) + torch.sum(y_lengths) * z.shape[1]) # duration loss - MSE # loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths) # duration loss - huber loss