From 243204bc3ed29cb7f300452aace027226bd6566d Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Fri, 18 May 2018 03:33:01 -0700
Subject: [PATCH] Add location sens attention

---
 layers/attention.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index f598e182..4326a712 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -34,7 +34,8 @@ class BahdanauAttention(nn.Module):
 class LocationSensitiveAttention(nn.Module):
     """Location sensitive attention following
     https://arxiv.org/pdf/1506.07503.pdf"""
-    def __init__(self, annot_dim, out_dim, hidden_dim):
+    def __init__(self, annot_dim, query_dim, hidden_dim):
+        super(LocationSensitiveAttention, self).__init__()
         loc_kernel_size = 31
         loc_dim = 32
         padding = int((loc_kernel_size -1) / 2)
@@ -46,13 +47,16 @@ class LocationSensitiveAttention(nn.Module):
         self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
         self.v = nn.Linear(hidden_dim, 1, bias=False)
 
-    def forward(self, query, annot, loc):
+    def forward(self, annot, query, loc):
         """
         Shapes:
-            - query: (batch, 1, dim) or (batch, dim)
             - annots: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
             - loc: (batch, 2, max_time)
         """
+        if query.dim() == 2:
+            # insert time-axis for broadcasting
+            query = query.unsqueeze(1)
         loc_conv = self.loc_conv(loc)
         loc_conv = loc_conv.transpose(1, 2)
         processed_loc = self.loc_linear(loc_conv)
@@ -68,8 +72,7 @@ class AttentionRNN(nn.Module):
     def __init__(self, out_dim, annot_dim, memory_dim):
         super(AttentionRNN, self).__init__()
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim, 3, out_dim)
-        self.score_mask_value = score_mask_value
+        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
 
     def forward(self, memory, context, rnn_state, annotations,
                 attention_vec, mask=None, annotations_lengths=None):
@@ -85,7 +88,7 @@ class AttentionRNN(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_output)
+        alignment = self.alignment_model(annotations, rnn_output, attention_vec)
 
         # TODO: needs recheck.
         if mask is not None:
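
For reference, below is a minimal, self-contained sketch of what the patched LocationSensitiveAttention amounts to after this change: annotations first, query second, query unsqueezed to (batch, 1, dim) when needed, and location features computed from the (batch, 2, max_time) attention history via a 1-D convolution. The layer names (loc_conv, loc_linear, query_layer, annot_layer, v) mirror those referenced in the diff, but the parts of the module not shown in the hunks, and the exact energy computation, are assumptions rather than the repo's verbatim code.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LocationSensitiveAttentionSketch(nn.Module):
        """Illustrative location-sensitive attention (Chorowski et al., 2015);
        not the repository's exact implementation."""
        def __init__(self, annot_dim, query_dim, hidden_dim,
                     loc_kernel_size=31, loc_dim=32):
            super().__init__()
            padding = (loc_kernel_size - 1) // 2
            # convolve the previous attention weights over time -> location features
            self.loc_conv = nn.Conv1d(2, loc_dim, kernel_size=loc_kernel_size,
                                      padding=padding, bias=False)
            self.loc_linear = nn.Linear(loc_dim, hidden_dim)
            self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
            self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
            self.v = nn.Linear(hidden_dim, 1, bias=False)

        def forward(self, annot, query, loc):
            # annot: (B, T, annot_dim); query: (B, query_dim) or (B, 1, query_dim);
            # loc: (B, 2, T) previous attention weights (e.g. last step + cumulative)
            if query.dim() == 2:
                query = query.unsqueeze(1)        # (B, 1, query_dim) for broadcasting
            loc_feat = self.loc_conv(loc)         # (B, loc_dim, T)
            loc_feat = loc_feat.transpose(1, 2)   # (B, T, loc_dim)
            energies = self.v(torch.tanh(
                self.query_layer(query)           # (B, 1, hidden)
                + self.annot_layer(annot)         # (B, T, hidden)
                + self.loc_linear(loc_feat)))     # (B, T, hidden)
            return F.softmax(energies.squeeze(-1), dim=1)   # alignment: (B, T)

A hypothetical call, with dimensions chosen only for illustration, would look like:

    attn = LocationSensitiveAttentionSketch(annot_dim=256, query_dim=256, hidden_dim=128)
    annot = torch.randn(4, 50, 256)   # encoder outputs
    query = torch.randn(4, 256)       # decoder RNN state
    loc = torch.zeros(4, 2, 50)       # attention weights from previous steps
    align = attn(annot, query, loc)   # (4, 50), each row sums to 1

This matches the argument order introduced by the patch (annot, query, loc) and the (batch, max_time) alignment shape noted in the AttentionRNN comment.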