From ab42396fbfbd647f8f7f67f660250d9f75219643 Mon Sep 17 00:00:00 2001
From: Thomas Werkmeister
Date: Thu, 25 Jul 2019 13:04:41 +0200
Subject: [PATCH] undo loc attn after fwd attn

---
 layers/common_layers.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/layers/common_layers.py b/layers/common_layers.py
index a652b8a6..0a7216ef 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -248,14 +248,15 @@ class Attention(nn.Module):
                     dim=1, keepdim=True)
         else:
             raise ValueError("Unknown value for attention norm type")
+
+        if self.location_attention:
+            self.update_location_attention(alignment)
+
         # apply forward attention if enabled
         if self.forward_attn:
             alignment = self.apply_forward_attention(alignment)
             self.alpha = alignment
 
-        if self.location_attention:
-            self.update_location_attention(alignment)
-
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         self.attention_weights = alignment
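
For context, a minimal, self-contained sketch of the ordering this patch
restores: the location-attention statistics are updated from the raw
normalized alignment before forward attention reshapes it. The toy module
below is an illustrative assumption, not the real Attention layer; only the
method names update_location_attention and apply_forward_attention mirror
the patched code, and the internals are simplified stand-ins.

    import torch

    class ToyAttention:
        """Toy stand-in showing the control flow fixed by this patch."""

        def __init__(self, seq_len):
            # cumulative alignment used by location-sensitive attention
            self.attention_weights_cum = torch.zeros(1, seq_len)
            # forward-attention state, initialized to attend to step 0
            self.alpha = torch.cat(
                [torch.ones(1, 1), torch.zeros(1, seq_len - 1)], dim=1)

        def update_location_attention(self, alignment):
            # accumulate the *raw* alignment for the location features
            self.attention_weights_cum = self.attention_weights_cum + alignment

        def apply_forward_attention(self, alignment):
            # combine alpha with its right-shifted copy, then renormalize
            shifted = torch.cat(
                [torch.zeros(1, 1), self.alpha[:, :-1]], dim=1)
            alpha = (self.alpha + shifted) * alignment + 1e-8
            return alpha / alpha.sum(dim=1, keepdim=True)

        def step(self, scores):
            alignment = torch.softmax(scores, dim=-1)
            # patched order: update location stats first...
            self.update_location_attention(alignment)
            # ...then let forward attention reshape the alignment
            alignment = self.apply_forward_attention(alignment)
            self.alpha = alignment
            return alignment

    attn = ToyAttention(seq_len=5)
    print(attn.step(torch.randn(1, 5)))

Before this patch, update_location_attention saw the alignment already
transformed by apply_forward_attention, so the cumulative location features
tracked the forward-attention output rather than the raw alignment; moving
the update ahead of the forward-attention step decouples the location
statistics from the forward-attention recursion.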