guided attn #12

This commit is contained in:
Eren Golge 2018-04-24 11:41:39 -07:00
parent e87d7d8d26
commit 3161a16101
1 changed files with 1 additions and 1 deletions

View File

@ -130,7 +130,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
mel_lengths_var)
print(M.shape)
print(alignments.shape)
attention_loss = criterion(M, alignments, mel_lengths_var)
attention_loss = criterion(alignments, M, mel_lengths_var)
loss = mel_loss + linear_loss + 0.2 * attention_loss
# backpass and check the grad norm