Update location attention after calculating forward attention

This commit is contained in:
Thomas Werkmeister 2019-07-24 11:49:07 +02:00
parent 40f56f9b00
commit f3dac0aa84
1 changed file with 3 additions and 2 deletions

View File

@ -248,13 +248,14 @@ class Attention(nn.Module):
dim=1, keepdim=True)
else:
raise ValueError("Unknown value for attention norm type")
if self.location_attention:
self.update_location_attention(alignment)
# apply forward attention if enabled
if self.forward_attn:
alignment = self.apply_forward_attention(alignment)
self.alpha = alignment
if self.location_attention:
self.update_location_attention(alignment)
context = torch.bmm(alignment.unsqueeze(1), inputs)
context = context.squeeze(1)
self.attention_weights = alignment