From a8ed3da6c4ec73b8c38a3fdbe331eeafe8edac54 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 24 Apr 2018 11:34:11 -0700 Subject: [PATCH] guided attn #7 --- train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train.py b/train.py index 76c6b755..a12e1f46 100644 --- a/train.py +++ b/train.py @@ -128,6 +128,8 @@ def train(model, criterion, data_loader, optimizer, epoch): + 0.5 * criterion(linear_output[:, :, :n_priority_freq], linear_spec_var[:, :, :n_priority_freq], mel_lengths_var) + print(M.shape) + print(alignments.shape) attention_loss = criterion(M, alignments, mel_lengths_var) loss = mel_loss + linear_loss + 0.2 * attention_loss