diff --git a/utils/generic_utils.py b/utils/generic_utils.py index d4d27875..83e958b3 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -133,7 +133,8 @@ def lr_decay(init_lr, global_step, warmup_steps): def create_attn_mask(N, T, g=0.05): - r'''creating attn mask for guided attention''' + r'''creating attn mask for guided attention + TODO: vectorize''' M = np.zeros([N, T]) for t in range(T): for n in range(N): @@ -141,7 +142,7 @@ def create_attn_mask(N, T, g=0.05): M[n, t] = val e_x = np.exp(M - np.max(M)) M = e_x / e_x.sum(axis=0) # only difference - M = Variable(torch.FloatTensor(M).t()).cuda() + M = torch.FloatTensor(M).t().cuda() M = torch.stack([M]*32) return M