From 7bf53721e0c6a764b45dacd926b9fddc77bd2da0 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 25 Apr 2018 08:14:07 -0700
Subject: [PATCH] bug fix

---
 utils/generic_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index d4d27875..83e958b3 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -133,7 +133,8 @@ def lr_decay(init_lr, global_step, warmup_steps):
 
 
 def create_attn_mask(N, T, g=0.05):
-    r'''creating attn mask for guided attention'''
+    r'''creating attn mask for guided attention
+    TODO: vectorize'''
     M = np.zeros([N, T])
     for t in range(T):
         for n in range(N):
@@ -141,7 +142,7 @@ def create_attn_mask(N, T, g=0.05):
             M[n, t] = val
     e_x = np.exp(M - np.max(M))
     M = e_x / e_x.sum(axis=0)  # only difference
-    M = Variable(torch.FloatTensor(M).t()).cuda()
+    M = torch.FloatTensor(M).t().cuda()
     M = torch.stack([M]*32)
     return M
 
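
A minimal vectorized sketch addressing the "TODO: vectorize" note above. The per-element formula (the `val` assignment) lies outside the shown hunks, so the guided-attention weight used here is only an assumption based on the standard formulation in Tachibana et al. (2017); the function name create_attn_mask_vectorized and the batch_size parameter are likewise hypothetical, not part of the patch.

    import numpy as np
    import torch


    def create_attn_mask_vectorized(N, T, g=0.05, batch_size=32):
        # Grid of normalized (decoder, encoder) positions, shape [N, T].
        n = np.arange(N)[:, None] / N
        t = np.arange(T)[None, :] / T
        # ASSUMED guided-attention weight (Tachibana et al. 2017); the exact
        # expression in utils/generic_utils.py is not visible in this hunk.
        M = 1.0 - np.exp(-((n - t) ** 2) / (2 * g ** 2))
        # Column-wise softmax, matching e_x / e_x.sum(axis=0) in the patch.
        e_x = np.exp(M - np.max(M))
        M = e_x / e_x.sum(axis=0)
        # Transpose to [T, N], move to GPU, and repeat along the batch
        # dimension, as in the patched code (minus the removed Variable
        # wrapper, which newer PyTorch versions no longer need).
        M = torch.FloatTensor(M).t().cuda()
        return torch.stack([M] * batch_size)

The double Python loop in the original scales as O(N*T) interpreted iterations, while the broadcasted version above builds the same [N, T] grid in a single NumPy expression; the hard-coded factor of 32 in torch.stack([M]*32) is kept as a default batch_size argument.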