fix forward attention

This commit is contained in:
Eren Golge 2019-04-18 18:36:01 +02:00
parent 01dbfb3a0f
commit 9ba13b2d2f
1 changed file with 5 additions and 5 deletions


@@ -152,7 +152,7 @@ class Attention(nn.Module):
         """
         B = inputs.shape[0]
         T = inputs.shape[1]
-        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1]], dim=1).to(inputs.device)
+        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
         self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

     def get_attention(self, query, processed_inputs, attention_cat):
@@ -183,16 +183,16 @@ class Attention(nn.Module):
     def apply_forward_attention(self, inputs, alignment, processed_query):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-        self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
-        alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
+        alpha = ((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) * alignment
+        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
-        context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
+        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
         context = context.squeeze(1)
         # compute transition agent
         if self.trans_agent:
             ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
             self.u = torch.sigmoid(self.ta(ta_input))
-        return context, alpha_norm, alignment
+        return context, self.alpha, alignment

     def forward(self, attention_hidden_state, inputs, processed_inputs,
                 attention_cat, mask):
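
The change folds the 1e-7 epsilon into the initial alpha instead of re-adding it at every decoder step, and normalizes alpha before storing it in self.alpha, so the next step's recurrence and the context vector both use a proper probability distribution. A minimal sketch of one step of that forward-attention recurrence, using illustrative shapes and dummy tensors rather than the repo's actual Attention module:

import torch
import torch.nn.functional as F

B, T = 2, 6                                          # batch size, encoder timesteps (illustrative)
alignment = torch.softmax(torch.randn(B, T), dim=1)  # current attention weights y_t(n)
u = torch.full((B, 1), 0.5)                          # transition agent probability

# initial alpha: all mass on the first encoder step, small epsilon elsewhere (the new init)
alpha = torch.cat([torch.ones(B, 1), torch.zeros(B, T - 1) + 1e-7], dim=1)

# one decoder step: alpha_t(n) = ((1 - u) * alpha_{t-1}(n) + u * alpha_{t-1}(n-1)) * y_t(n)
prev_alpha = F.pad(alpha[:, :-1], (1, 0))            # shift right to get alpha_{t-1}(n-1)
alpha = ((1 - u) * alpha + u * prev_alpha) * alignment
alpha = alpha / alpha.sum(dim=1, keepdim=True)       # renormalize, as the patched code does

inputs = torch.randn(B, T, 4)                        # dummy encoder outputs with feature dim 4
context = torch.bmm(alpha.unsqueeze(1), inputs).squeeze(1)  # (B, 4) context vector
print(alpha.sum(dim=1))                              # each row sums to 1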