correct glow-tts loss

This commit is contained in:
erogol 2020-09-27 03:28:42 +02:00
parent 665f7ca714
commit 6a70c63f24
1 changed files with 1 additions and 1 deletions

View File

@ -270,7 +270,7 @@ class GlowTTSLoss(torch.nn.Module):
pz = torch.sum(scales) + 0.5 * torch.sum(
torch.exp(-2 * scales) * (z - means)**2)
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (
torch.sum(y_lengths // 2) * 2 * z.shape[1])
torch.sum(y_lengths) * z.shape[1])
# duration loss - MSE
# loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
# duration loss - huber loss