simplified code for fwd attn

Thomas Werkmeister 2019-07-24 11:47:06 +02:00
parent a6118564d5
commit 40f56f9b00
1 changed file with 20 additions and 22 deletions

@@ -201,17 +201,17 @@ class Attention(nn.Module):
         self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention
 
-    def apply_forward_attention(self, inputs, alignment, query):
+    def apply_forward_attention(self, alignment):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device),
-                           (1, 0, 0, 0))
+        fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device),
+                                  (1, 0, 0, 0))
         # compute transition potentials
         alpha = ((1 - self.u) * self.alpha
-                 + self.u * prev_alpha
+                 + self.u * fwd_shifted_alpha
                  + 1e-8) * alignment
         # force incremental alignment
         if not self.training and self.forward_attn_mask:
-            _, n = prev_alpha.max(1)
+            _, n = fwd_shifted_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
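For reference, the hunk above computes the forward-attention transition potentials: the previous attention weights are shifted one encoder step to the right (the F.pad call, so position n receives the old weight of position n-1) and mixed with the unshifted weights via the transition probability u. Before the renormalization in the next hunk, the update is, as a sketch in the usual forward-attention notation (e_t the freshly normalized alignment, u_t the transition-agent output):

\[
\alpha'_t(n) = \bigl( (1 - u_t)\,\alpha_{t-1}(n) + u_t\,\alpha_{t-1}(n-1) + 10^{-8} \bigr)\, e_t(n), \qquad \alpha_{t-1}(-1) = 0 .
\]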
@@ -221,16 +221,9 @@ class Attention(nn.Module):
                 alpha[b,
                       (n[b] - 2
                        )] = 0.01 * val[b]  # smoothing factor for the prev step
-        # compute attention weights
-        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
-        # compute context
-        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        # compute transition agent
-        if self.trans_agent:
-            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
-            self.u = torch.sigmoid(self.ta(ta_input))
-        return context, self.alpha
+        # renormalize attention weights
+        alpha = alpha / alpha.sum(dim=1, keepdim=True)
+        return alpha
 
     def forward(self, query, inputs, processed_inputs, mask):
         if self.location_attention:
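To make the simplified method easier to follow in isolation, here is a minimal runnable sketch of the same shift, transition-potential, and renormalization steps (not the repository's module; the shapes [batch, encoder_steps] and the standalone tensor names are assumptions for illustration):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
alpha = torch.softmax(torch.rand(2, 5), dim=1)      # previous attention weights
alignment = torch.softmax(torch.rand(2, 5), dim=1)  # current alignment e_t
u = torch.rand(2, 1)                                # transition probability in [0, 1]

# shift one encoder step forward: fwd_shifted_alpha[:, n] == alpha[:, n - 1]
fwd_shifted_alpha = F.pad(alpha[:, :-1].clone(), (1, 0, 0, 0))

# transition potentials, then renormalize over the encoder axis
new_alpha = ((1 - u) * alpha + u * fwd_shifted_alpha + 1e-8) * alignment
new_alpha = new_alpha / new_alpha.sum(dim=1, keepdim=True)

assert torch.allclose(new_alpha.sum(dim=1), torch.ones(2))

The keepdim=True division yields the same weights as the removed .sum(dim=1).unsqueeze(1) expression; only the intermediate reshape is gone.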
@@ -254,15 +247,20 @@ class Attention(nn.Module):
                                 attention).sum(
                                     dim=1, keepdim=True)
         else:
-            raise RuntimeError("Unknown value for attention norm type")
+            raise ValueError("Unknown value for attention norm type")
         if self.location_attention:
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(
-                inputs, alignment, query)
-        else:
-            context = torch.bmm(alignment.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            self.attention_weights = alignment
+            alignment = self.apply_forward_attention(alignment)
+            self.alpha = alignment
+
+        context = torch.bmm(alignment.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        self.attention_weights = alignment
+
+        # compute transition agent
+        if self.forward_attn and self.trans_agent:
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
         return context
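Taken together, forward attention now only rewrites the alignment; the context vector and the transition agent run on a single shared path whether or not forward attention is enabled. A rough standalone sketch of that control flow (function and parameter names here are illustrative, not the module's API):

import torch

def attention_step(alignment, inputs, query,
                   forward_attn_fn=None, trans_agent_layer=None):
    # alignment: [B, T_enc]   inputs: [B, T_enc, D_enc]   query: [B, 1, D_q]
    if forward_attn_fn is not None:
        # forward attention only reshapes the alignment distribution
        alignment = forward_attn_fn(alignment)
    # context is always computed from the (possibly rewritten) alignment
    context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)
    u = None
    if trans_agent_layer is not None:
        # transition agent conditions on the new context and the current query
        ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
        u = torch.sigmoid(trans_agent_layer(ta_input))
    return context, alignment, u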