Add location sensitive attention

Eren Golge 2018-05-18 03:33:01 -07:00
parent 7b9fd63649
commit 243204bc3e
1 changed file with 9 additions and 6 deletions


@@ -34,7 +34,8 @@ class BahdanauAttention(nn.Module):
 class LocationSensitiveAttention(nn.Module):
     """Location sensitive attention following
     https://arxiv.org/pdf/1506.07503.pdf"""
-    def __init__(self, annot_dim, out_dim, hidden_dim):
+    def __init__(self, annot_dim, query_dim, hidden_dim):
+        super(LocationSensitiveAttention, self).__init__()
         loc_kernel_size = 31
         loc_dim = 32
         padding = int((loc_kernel_size -1) / 2)
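One point worth noting about the constants above: with loc_kernel_size = 31 and padding = 15, the location convolution preserves max_time, so its output lines up with the per-timestep annotation scores. A quick check, assuming the conv maps the 2 location channels to loc_dim filters (batch size and max_time below are illustrative only):

import torch
from torch import nn

loc_kernel_size = 31
loc_dim = 32
padding = int((loc_kernel_size - 1) / 2)    # 15
loc_conv = nn.Conv1d(2, loc_dim, kernel_size=loc_kernel_size,
                     padding=padding, bias=False)
loc = torch.randn(4, 2, 100)                # (batch, 2, max_time)
print(loc_conv(loc).shape)                  # torch.Size([4, 32, 100]) -- max_time preserved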
@@ -46,13 +47,16 @@ class LocationSensitiveAttention(nn.Module):
         self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
         self.v = nn.Linear(hidden_dim, 1, bias=False)

-    def forward(self, query, annot, loc):
+    def forward(self, annot, query, loc):
         """
         Shapes:
-            - query: (batch, 1, dim) or (batch, dim)
             - annots: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
             - loc: (batch, 2, max_time)
         """
+        if query.dim() == 2:
+            # insert time-axis for broadcasting
+            query = query.unsqueeze(1)
         loc_conv = self.loc_conv(loc)
         loc_conv = loc_conv.transpose(1, 2)
         processed_loc = self.loc_linear(loc_conv)
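Taken together, the two hunks above imply roughly the following module. This is a minimal sketch for orientation, not the file's full contents: the layer constructions that fall between the hunks (loc_conv, loc_linear, query_layer) and the tanh scoring step are reconstructed assumptions, and masking/normalization of the scores is left to the caller, as suggested by the AttentionRNN hunks below.

import torch
from torch import nn


class LocationSensitiveAttention(nn.Module):
    """Location sensitive attention following
    https://arxiv.org/pdf/1506.07503.pdf"""

    def __init__(self, annot_dim, query_dim, hidden_dim):
        super(LocationSensitiveAttention, self).__init__()
        loc_kernel_size = 31
        loc_dim = 32
        padding = int((loc_kernel_size - 1) / 2)
        # Location features from past alignments; 2 input channels
        # (e.g. previous and cumulative alignments).
        self.loc_conv = nn.Conv1d(2, loc_dim, kernel_size=loc_kernel_size,
                                  padding=padding, bias=False)
        self.loc_linear = nn.Linear(loc_dim, hidden_dim)
        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)  # assumption
        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, annot, query, loc):
        """
        Shapes:
            - annot: (batch, max_time, dim)
            - query: (batch, 1, dim) or (batch, dim)
            - loc: (batch, 2, max_time)
        """
        if query.dim() == 2:
            # insert time-axis for broadcasting
            query = query.unsqueeze(1)
        loc_conv = self.loc_conv(loc)               # (batch, loc_dim, max_time)
        loc_conv = loc_conv.transpose(1, 2)         # (batch, max_time, loc_dim)
        processed_loc = self.loc_linear(loc_conv)
        processed_query = self.query_layer(query)   # (batch, 1, hidden_dim)
        processed_annot = self.annot_layer(annot)   # (batch, max_time, hidden_dim)
        # e_{ij} = v^T tanh(W s_{i-1} + V h_j + U f_{i,j}); scoring step assumed.
        scores = self.v(torch.tanh(processed_query + processed_annot + processed_loc))
        return scores.squeeze(-1)                   # (batch, max_time), unnormalized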
@@ -68,8 +72,7 @@ class AttentionRNN(nn.Module):
     def __init__(self, out_dim, annot_dim, memory_dim):
         super(AttentionRNN, self).__init__()
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim, 3, out_dim)
-        self.score_mask_value = score_mask_value
+        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)

     def forward(self, memory, context, rnn_state, annotations,
                 attention_vec, mask=None, annotations_lengths=None):
@@ -85,7 +88,7 @@ class AttentionRNN(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_output)
+        alignment = self.alignment_model(annotations, rnn_output, attention_vec)
         # TODO: needs recheck.
         if mask is not None:
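Finally, a small usage sketch of the new call order (annotations first, then query, then location features), assuming the LocationSensitiveAttention sketch above is in scope; batch and dimension values here are purely illustrative:

import torch

batch, max_time, annot_dim, query_dim, hidden_dim = 2, 50, 256, 128, 128

attn = LocationSensitiveAttention(annot_dim, query_dim, hidden_dim)
annot = torch.randn(batch, max_time, annot_dim)   # encoder outputs
query = torch.randn(batch, query_dim)             # decoder state, no time axis
loc = torch.randn(batch, 2, max_time)             # location features

# Matches the updated signature forward(annot, query, loc) and mirrors the
# AttentionRNN hunk above, which now passes attention_vec as the third argument.
scores = attn(annot, query, loc)
print(scores.shape)                               # torch.Size([2, 50])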