From 134f5455df71ec561e9126d54f899481201f55df Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 24 Apr 2018 11:32:22 -0700 Subject: [PATCH] guided attn #6 --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 7756ab37..76c6b755 100644 --- a/train.py +++ b/train.py @@ -115,7 +115,7 @@ def train(model, criterion, data_loader, optimizer, epoch): M[n, t] = val e_x = np.exp(M - np.max(M)) M = e_x / e_x.sum(axis=0) # only difference - M = Variable(torch.FloatTensor(M)) + M = Variable(torch.FloatTensor(M)).cuda() M = torch.stack([M]*32) # forward pass