This commit is contained in:
Eren Golge 2018-04-25 08:14:07 -07:00
parent 8ef8ddb915
commit 722896f70a
1 changed files with 3 additions and 2 deletions

View File

@ -133,7 +133,8 @@ def lr_decay(init_lr, global_step, warmup_steps):
def create_attn_mask(N, T, g=0.05):
r'''creating attn mask for guided attention'''
r'''creating attn mask for guided attention
TODO: vectorize'''
M = np.zeros([N, T])
for t in range(T):
for n in range(N):
@ -141,7 +142,7 @@ def create_attn_mask(N, T, g=0.05):
M[n, t] = val
e_x = np.exp(M - np.max(M))
M = e_x / e_x.sum(axis=0) # only difference
M = Variable(torch.FloatTensor(M).t()).cuda()
M = torch.FloatTensor(M).t().cuda()
M = torch.stack([M]*32)
return M