guided attn #4

This commit is contained in:
Eren Golge 2018-04-24 11:30:58 -07:00
parent 937d1140f4
commit 9d6c08ab72
1 changed files with 1 additions and 1 deletions

View File

@ -115,7 +115,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
M[n, t] = val
e_x = np.exp(M - np.max(M))
M = e_x / e_x.sum(axis=0) # only difference
M = Variable(M)
M = Variable(torch.FloatTensor(M))
M = torch.stack([M]*32)
# forward pass