diff --git a/layers/common_layers.py b/layers/common_layers.py
index bfdd6775..a652b8a6 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -248,13 +248,14 @@ class Attention(nn.Module):
                 dim=1, keepdim=True)
         else:
             raise ValueError("Unknown value for attention norm type")
-        if self.location_attention:
-            self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
             alignment = self.apply_forward_attention(alignment)
             self.alpha = alignment
+        if self.location_attention:
+            self.update_location_attention(alignment)
+
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         self.attention_weights = alignment
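
For context, a minimal sketch of the forward-pass ordering this hunk produces: the forward-attention filter is applied first, and the location-attention statistics are then updated from the filtered alignment rather than from the raw normalized scores. The class and helper bodies below (`AttentionOrderingSketch`, `apply_forward_attention`, `update_location_attention`) are simplified stand-ins written for illustration only, not the repository's actual implementations.

```python
import torch
import torch.nn as nn


class AttentionOrderingSketch(nn.Module):
    """Illustrates only the post-patch ordering; helper logic is simplified."""

    def __init__(self, forward_attn=True, location_attention=True):
        super().__init__()
        self.forward_attn = forward_attn
        self.location_attention = location_attention

    def init_states(self, inputs):
        batch, timesteps, _ = inputs.size()
        # alpha starts fully on the first encoder step (forward attention)
        self.alpha = torch.cat(
            [torch.ones(batch, 1), torch.zeros(batch, timesteps - 1)], dim=1)
        self.attention_weights_cum = torch.zeros(batch, timesteps)

    def apply_forward_attention(self, alignment):
        # simplified monotonic filter: mix alpha with its right-shifted copy
        shifted = torch.cat(
            [torch.zeros_like(self.alpha[:, :1]), self.alpha[:, :-1]], dim=1)
        alpha = (self.alpha + shifted + 1e-8) * alignment
        return alpha / alpha.sum(dim=1, keepdim=True)

    def update_location_attention(self, alignment):
        # cumulative attention map consumed by the location layer
        self.attention_weights_cum = self.attention_weights_cum + alignment

    def forward(self, attention_scores, inputs):
        alignment = torch.softmax(attention_scores, dim=-1)
        # 1) forward attention filters the alignment first ...
        if self.forward_attn:
            alignment = self.apply_forward_attention(alignment)
            self.alpha = alignment
        # 2) ... so location statistics accumulate the filtered alignment
        if self.location_attention:
            self.update_location_attention(alignment)
        context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)
        self.attention_weights = alignment
        return context


# usage sketch: (batch=2, encoder steps=5, feature dim=8)
att = AttentionOrderingSketch()
inputs = torch.randn(2, 5, 8)
att.init_states(inputs)
context = att(torch.randn(2, 5), inputs)
```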