diff --git a/layers/common_layers.py b/layers/common_layers.py
index a652b8a6..0a7216ef 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -248,14 +248,15 @@ class Attention(nn.Module):
                 dim=1, keepdim=True)
         else:
             raise ValueError("Unknown value for attention norm type")
+
+        if self.location_attention:
+            self.update_location_attention(alignment)
+
         # apply forward attention if enabled
         if self.forward_attn:
             alignment = self.apply_forward_attention(alignment)
             self.alpha = alignment
 
-        if self.location_attention:
-            self.update_location_attention(alignment)
-
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         self.attention_weights = alignment