From 3348d462d2b0298f5906707f2e8a5c21642323b1 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Thu, 17 May 2018 08:03:16 -0700
Subject: [PATCH] Add location sensitive attention

---
 layers/attention.py | 59 ++++++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 38 insertions(+), 21 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index 51d3542a..9c63a85f 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -29,34 +29,50 @@ class BahdanauAttention(nn.Module):
 
         # (batch, max_time)
         return alignment.squeeze(-1)
-
-
-def get_mask_from_lengths(inputs, inputs_lengths):
-    """Get mask tensor from list of length
-
-    Args:
-        inputs: Tensor in size (batch, max_time, dim)
-        inputs_lengths: array like
-    """
-    mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
-    for idx, l in enumerate(inputs_lengths):
-        mask[idx][:l] = 1
-    return ~mask
+
+
+class LocationSensitiveAttention(nn.Module):
+    """Location-sensitive attention following
+    https://arxiv.org/pdf/1506.07503.pdf"""
+    def __init__(self, annot_dim, query_dim, hidden_dim):
+        super(LocationSensitiveAttention, self).__init__()
+        loc_kernel_size = 31
+        loc_dim = 32
+        padding = int((loc_kernel_size - 1) / 2)
+        self.loc_conv = nn.Conv1d(2, loc_dim,
+                                  kernel_size=loc_kernel_size, stride=1,
+                                  padding=padding, bias=False)
+        self.loc_linear = nn.Linear(loc_dim, hidden_dim)
+        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
+        self.v = nn.Linear(hidden_dim, 1, bias=False)
+
+    def forward(self, query, annot, loc):
+        """
+        Shapes:
+            - query: (batch, 1, dim) or (batch, dim)
+            - annot: (batch, max_time, dim)
+            - loc: (batch, 2, max_time)
+        """
+        loc_conv = self.loc_conv(loc)
+        loc_conv = loc_conv.transpose(1, 2)
+        processed_loc = self.loc_linear(loc_conv)
+        processed_query = self.query_layer(query)
+        processed_annots = self.annot_layer(annot)
+        alignment = self.v(nn.functional.tanh(
+            processed_query + processed_annots + processed_loc))
+        # (batch, max_time)
+        return alignment.squeeze(-1)
 
 
 class AttentionRNN(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim,
-                 score_mask_value=-float("inf")):
+    def __init__(self, out_dim, annot_dim, memory_dim):
         super(AttentionRNN, self).__init__()
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
-        self.score_mask_value = score_mask_value
+        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
 
     def forward(self, memory, context, rnn_state, annotations,
-                mask=None, annotations_lengths=None):
-
-        if annotations_lengths is not None and mask is None:
-            mask = get_mask_from_lengths(annotations, annotations_lengths)
+                attention_vec, mask=None, annotations_lengths=None):
 
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
@@ -85,3 +101,4 @@ class AttentionRNN(nn.Module):
         context = torch.bmm(alignment.unsqueeze(1), annotations)
         context = context.squeeze(1)
         return rnn_output, context, alignment
+
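
A minimal usage sketch for the new module, not part of the patch itself. It assumes the two channels of loc are the previous step's attention weights and their running sum, which is the usual choice for location-sensitive attention; the patch only fixes the shape as (batch, 2, max_time). The dimension values below are arbitrary placeholders.

import torch
from layers.attention import LocationSensitiveAttention

batch, max_time = 2, 50
annot_dim, query_dim, hidden_dim = 256, 256, 128
attn = LocationSensitiveAttention(annot_dim, query_dim, hidden_dim)

annot = torch.rand(batch, max_time, annot_dim)     # encoder outputs (annotations)
query = torch.rand(batch, 1, query_dim)            # current decoder state
prev_align = torch.zeros(batch, max_time)          # attention weights from the previous step
cum_align = torch.zeros(batch, max_time)           # running sum of past weights (assumed channel)
loc = torch.stack([prev_align, cum_align], dim=1)  # (batch, 2, max_time)

scores = attn(query, annot, loc)                   # unnormalized scores, (batch, max_time)
align = torch.nn.functional.softmax(scores, dim=-1)
cum_align = cum_align + align                      # carry over to the next decoder step

How AttentionRNN assembles loc for its alignment model is outside the hunks shown here.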