Add location sensitive attention

This commit is contained in:
Eren Golge 2018-05-17 08:03:16 -07:00
parent e6112f7b2d
commit 3348d462d2
1 changed file with 37 additions and 20 deletions
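For reference, the mechanism this commit implements is the location-sensitive attention of Chorowski et al. (the https://arxiv.org/pdf/1506.07503.pdf reference cited in the new class docstring): the usual additive (Bahdanau) score is extended with features obtained by convolving earlier alignments, so the scorer knows where it attended before. A compact statement of the scoring rule, with symbols as in that paper (decoder state s_i, encoder annotation h_j, previous alignment alpha_{i-1}), is:

```latex
% Location-sensitive attention score (Chorowski et al., 2015).
% F is a bank of 1-D convolution filters applied to the previous alignment.
f_i = F * \alpha_{i-1}
e_{i,j} = w^{\top} \tanh\!\left(W s_i + V h_j + U f_{i,j} + b\right)
\alpha_{i,j} = \frac{\exp(e_{i,j})}{\sum_{k} \exp(e_{i,k})}
```

The new `loc` input has two channels, which suggests both the previous and the accumulated alignments are fed to the convolution; the commit itself only fixes the channel count of `loc_conv` at 2, so that reading is an assumption.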


@@ -31,32 +31,48 @@ class BahdanauAttention(nn.Module):
         return alignment.squeeze(-1)
 
 
-def get_mask_from_lengths(inputs, inputs_lengths):
-    """Get mask tensor from list of length
-
-    Args:
-        inputs: Tensor in size (batch, max_time, dim)
-        inputs_lengths: array like
-    """
-    mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
-    for idx, l in enumerate(inputs_lengths):
-        mask[idx][:l] = 1
-    return ~mask
+class LocationSensitiveAttention(nn.Module):
+    """Location sensitive attention following
+    https://arxiv.org/pdf/1506.07503.pdf"""
+    def __init__(self, annot_dim, out_dim, hidden_dim):
+        loc_kernel_size = 31
+        loc_dim = 32
+        padding = int((loc_kernel_size - 1) / 2)
+        self.loc_conv = nn.Conv1d(2, loc_dim,
+                                  kernel_size=loc_kernel_size, stride=1,
+                                  padding=padding, bias=False)
+        self.loc_linear = nn.Linear(loc_dim, hidden_dim)
+        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
+        self.v = nn.Linear(hidden_dim, 1, bias=False)
+
+    def forward(self, query, annot, loc):
+        """
+        Shapes:
+            - query: (batch, 1, dim) or (batch, dim)
+            - annots: (batch, max_time, dim)
+            - loc: (batch, 2, max_time)
+        """
+        loc_conv = self.loc_conv(loc)
+        loc_conv = loc_conv.transpose(1, 2)
+        processed_loc = self.loc_linear(loc_conv)
+        processed_query = self.query_layer(query)
+        processed_annots = self.annot_layer(annot)
+        alignment = self.v(nn.functional.tanh(
+            processed_query + processed_annots + processed_loc))
+        # (batch, max_time)
+        return alignment.squeeze(-1)
 
 
 class AttentionRNN(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim,
-                 score_mask_value=-float("inf")):
+    def __init__(self, out_dim, annot_dim, memory_dim):
         super(AttentionRNN, self).__init__()
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim, 3, out_dim)
         self.score_mask_value = score_mask_value
 
     def forward(self, memory, context, rnn_state, annotations,
-                mask=None, annotations_lengths=None):
-        if annotations_lengths is not None and mask is None:
-            mask = get_mask_from_lengths(annotations, annotations_lengths)
+                attention_vec, mask=None, annotations_lengths=None):
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
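As committed, the new module is not quite runnable: `__init__` never calls `super().__init__()`, `query_layer` references an undefined `query_dim`, the constructor is invoked with five arguments although it accepts three, and `self.score_mask_value = score_mask_value` still reads a parameter the new `__init__` signature dropped. A minimal runnable sketch of the same idea, not the committed code, with `query_dim` made an explicit argument and the kernel and filter sizes kept as defaults, could look like this:

```python
import torch
import torch.nn as nn


class LocationSensitiveAttention(nn.Module):
    """Location-sensitive attention (https://arxiv.org/pdf/1506.07503.pdf).

    Sketch of the committed module with `super().__init__()` added and the
    undefined `query_dim` turned into an explicit constructor argument.
    """

    def __init__(self, annot_dim, query_dim, hidden_dim,
                 loc_kernel_size=31, loc_dim=32):
        super().__init__()
        padding = (loc_kernel_size - 1) // 2
        # Two input channels, e.g. previous + cumulative alignments.
        self.loc_conv = nn.Conv1d(2, loc_dim, kernel_size=loc_kernel_size,
                                  stride=1, padding=padding, bias=False)
        self.loc_linear = nn.Linear(loc_dim, hidden_dim)
        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, query, annot, loc):
        """
        Shapes:
            query: (batch, query_dim) or (batch, 1, query_dim)
            annot: (batch, max_time, annot_dim)
            loc:   (batch, 2, max_time)
        Returns:
            unnormalized alignment scores, (batch, max_time)
        """
        if query.dim() == 2:
            query = query.unsqueeze(1)                    # (batch, 1, query_dim)
        processed_loc = self.loc_linear(
            self.loc_conv(loc).transpose(1, 2))           # (batch, max_time, hidden)
        processed_query = self.query_layer(query)         # (batch, 1, hidden)
        processed_annot = self.annot_layer(annot)         # (batch, max_time, hidden)
        alignment = self.v(torch.tanh(
            processed_query + processed_annot + processed_loc))
        return alignment.squeeze(-1)                      # (batch, max_time)
```

Broadcasting the (batch, 1, hidden) query term against the (batch, max_time, hidden) annotation and location terms is what lets the single linear layer `v` produce one score per encoder step.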
@@ -85,3 +101,4 @@ class AttentionRNN(nn.Module):
         context = torch.bmm(alignment.unsqueeze(1), annotations)
         context = context.squeeze(1)
         return rnn_output, context, alignment
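To show how the pieces fit together, here is a hypothetical single decoder step in the spirit of the modified `AttentionRNN.forward`, reusing the `LocationSensitiveAttention` sketch above. It assumes the new `attention_vec` argument carries the previous and cumulative alignments, stacked into the two location channels; all sizes and tensor names are made up for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sizes; annot_dim must equal out_dim here because the GRU input is
# built as cat(memory, context) with width out_dim + memory_dim.
batch, max_time = 2, 50
annot_dim = out_dim = 256
memory_dim = 80

attn = LocationSensitiveAttention(annot_dim, out_dim, out_dim)  # sketch above
rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)

memory = torch.zeros(batch, memory_dim)         # current decoder input frame
context = torch.zeros(batch, annot_dim)         # previous context vector
rnn_state = torch.zeros(batch, out_dim)         # previous GRU state
annotations = torch.randn(batch, max_time, annot_dim)
prev_align = torch.zeros(batch, max_time)       # alpha_{i-1}
cum_align = torch.zeros(batch, max_time)        # sum of earlier alignments

# Concat the input frame and the previous context, advance the attention RNN.
rnn_output = rnn_cell(torch.cat((memory, context), -1), rnn_state)

# Two location channels: previous and cumulative alignments.
loc = torch.stack((prev_align, cum_align), dim=1)   # (batch, 2, max_time)
scores = attn(rnn_output, annotations, loc)         # (batch, max_time)
alignment = F.softmax(scores, dim=-1)

# Same context computation as the second hunk above.
context = torch.bmm(alignment.unsqueeze(1), annotations).squeeze(1)
cum_align = cum_align + alignment
```

Carrying the accumulated alignment forward between steps is what gives the location features a notion of how far attention has already advanced through the input sequence.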