From f3dac0aa840a893f7222b6444d5bc7f5f40d623d Mon Sep 17 00:00:00 2001
From: Thomas Werkmeister
Date: Wed, 24 Jul 2019 11:49:07 +0200
Subject: [PATCH] updating location attn after calculating fwd attention

---
 layers/common_layers.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/layers/common_layers.py b/layers/common_layers.py
index bfdd6775..a652b8a6 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -248,13 +248,14 @@ class Attention(nn.Module):
                                              dim=1, keepdim=True)
         else:
             raise ValueError("Unknown value for attention norm type")
-        if self.location_attention:
-            self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
             alignment = self.apply_forward_attention(alignment)
             self.alpha = alignment

+        if self.location_attention:
+            self.update_location_attention(alignment)
+
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         self.attention_weights = alignment
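
Note: the patch moves `update_location_attention` so that, when forward attention is enabled, the location-attention statistics are accumulated from the alignment that is actually used to compute the context vector, rather than the pre-forward-attention alignment. A minimal sketch of the patched ordering is below; `attention_step` is a hypothetical standalone helper, and the `softmax` and running-sum lines are stand-ins for the repo's `apply_forward_attention` and `update_location_attention` (the real methods implement the forward-attention recursion and the cumulative-alignment update inside `Attention`).

```python
import torch

def attention_step(alignment, inputs, attn_weights_cum):
    """Hypothetical sketch of the patched ordering: apply forward
    attention first, then update the location statistics from the
    alignment that is actually used for the context vector."""
    # stand-in for self.apply_forward_attention(alignment)
    alignment = torch.softmax(alignment, dim=-1)
    # stand-in for self.update_location_attention(alignment):
    # location-sensitive attention accumulates past alignments,
    # here the *post*-forward-attention alignment
    attn_weights_cum = attn_weights_cum + alignment
    # same context computation as in Attention.forward
    context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)
    return context, alignment, attn_weights_cum

# toy usage: batch of 2, 5 encoder steps, 8-dim encoder outputs
B, T, D = 2, 5, 8
inputs = torch.randn(B, T, D)
alignment = torch.randn(B, T)
cum = torch.zeros(B, T)
context, alignment, cum = attention_step(alignment, inputs, cum)
print(context.shape)  # torch.Size([2, 8])
```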