From 40f56f9b000bb03384ebe883c03380b260a6a205 Mon Sep 17 00:00:00 2001
From: Thomas Werkmeister
Date: Wed, 24 Jul 2019 11:47:06 +0200
Subject: [PATCH] simplified code for fwd attn

---
 layers/common_layers.py | 42 ++++++++++++++++++++----------------------
 1 file changed, 20 insertions(+), 22 deletions(-)

diff --git a/layers/common_layers.py b/layers/common_layers.py
index bc353be3..bfdd6775 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -201,17 +201,17 @@ class Attention(nn.Module):
         self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention
 
-    def apply_forward_attention(self, inputs, alignment, query):
+    def apply_forward_attention(self, alignment):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device),
-                           (1, 0, 0, 0))
+        fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device),
+                                  (1, 0, 0, 0))
         # compute transition potentials
         alpha = ((1 - self.u) * self.alpha
-                 + self.u * prev_alpha
+                 + self.u * fwd_shifted_alpha
                  + 1e-8) * alignment
         # force incremental alignment
         if not self.training and self.forward_attn_mask:
-            _, n = prev_alpha.max(1)
+            _, n = fwd_shifted_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
@@ -221,16 +221,9 @@ class Attention(nn.Module):
                 alpha[b,
                       (n[b] - 2
                        )] = 0.01 * val[b]  # smoothing factor for the prev step
-        # compute attention weights
-        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
-        # compute context
-        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        # compute transition agent
-        if self.trans_agent:
-            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
-            self.u = torch.sigmoid(self.ta(ta_input))
-        return context, self.alpha
+        # renormalize attention weights
+        alpha = alpha / alpha.sum(dim=1, keepdim=True)
+        return alpha
 
     def forward(self, query, inputs, processed_inputs, mask):
         if self.location_attention:
@@ -254,15 +247,20 @@ class Attention(nn.Module):
                 attention).sum(
                     dim=1, keepdim=True)
         else:
-            raise RuntimeError("Unknown value for attention norm type")
+            raise ValueError("Unknown value for attention norm type")
         if self.location_attention:
            self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(
-                inputs, alignment, query)
-        else:
-            context = torch.bmm(alignment.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            self.attention_weights = alignment
+            alignment = self.apply_forward_attention(alignment)
+            self.alpha = alignment
+
+        context = torch.bmm(alignment.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        self.attention_weights = alignment
+
+        # compute transition agent
+        if self.forward_attn and self.trans_agent:
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
         return context
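
After this change, apply_forward_attention() only shifts the previous weights, blends them with the transition probability, and renormalizes, while the context vector and the transition-agent update happen once in forward() for both attention modes. The snippet below is a minimal, self-contained sketch of that forward-attention step; the function name forward_attention_step, the tensor shapes (B, T), and the fixed transition probability u are illustrative assumptions, not part of the repository's Attention class, which keeps alpha and u as module state.

    # Sketch only: mirrors the update in apply_forward_attention() with state
    # passed in explicitly instead of being stored on the module.
    import torch
    import torch.nn.functional as F

    def forward_attention_step(alignment, alpha, u):
        # alignment: (B, T) attention weights for the current decoder step
        # alpha:     (B, T) forward-attention weights from the previous step
        # u:         (B, 1) transition probability in [0, 1]
        # shift the previous alpha one encoder position to the right
        fwd_shifted_alpha = F.pad(alpha[:, :-1], (1, 0, 0, 0))
        # blend "stay" and "move forward" transitions, weight by the new alignment
        new_alpha = ((1 - u) * alpha + u * fwd_shifted_alpha + 1e-8) * alignment
        # renormalize so the weights sum to one over the encoder axis
        return new_alpha / new_alpha.sum(dim=1, keepdim=True)

    if __name__ == "__main__":
        B, T = 2, 6
        alignment = torch.softmax(torch.randn(B, T), dim=1)
        alpha = F.one_hot(torch.zeros(B, dtype=torch.long), T).float()  # start at position 0
        u = torch.full((B, 1), 0.5)  # assumed constant; the class learns it via self.ta
        alpha = forward_attention_step(alignment, alpha, u)
        print(alpha.sum(dim=1))  # each row sums to 1

Keeping the context computation (torch.bmm over the encoder outputs) in forward() means the same code path serves both the plain-alignment and forward-attention cases, which is what removes the duplicated else branch in the patch.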