Cache attention annot vectors for the whole sequence.

This commit is contained in:
Eren Golge 2018-12-11 16:06:02 +01:00
parent 211a20a47a
commit dc3d09304e
2 changed files with 14 additions and 5 deletions

View File

@ -56,6 +56,7 @@ class LocationSensitiveAttention(nn.Module):
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True) self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True) self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
self.v = nn.Linear(attn_dim, 1, bias=False) self.v = nn.Linear(attn_dim, 1, bias=False)
self.processed_annots = None
# self.init_layers() # self.init_layers()
def init_layers(self): def init_layers(self):
@ -72,6 +73,9 @@ class LocationSensitiveAttention(nn.Module):
self.v.weight, self.v.weight,
gain=torch.nn.init.calculate_gain('linear')) gain=torch.nn.init.calculate_gain('linear'))
def reset(self):
self.processed_annots = None
def forward(self, annot, query, loc): def forward(self, annot, query, loc):
""" """
Shapes: Shapes:
@ -86,9 +90,11 @@ class LocationSensitiveAttention(nn.Module):
loc_conv = loc_conv.transpose(1, 2) loc_conv = loc_conv.transpose(1, 2)
processed_loc = self.loc_linear(loc_conv) processed_loc = self.loc_linear(loc_conv)
processed_query = self.query_layer(query) processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annot) # cache annots
if self.processed_annots is None:
self.processed_annots = self.annot_layer(annot)
alignment = self.v( alignment = self.v(
torch.tanh(processed_query + processed_annots + processed_loc)) torch.tanh(processed_query + self.processed_annots + processed_loc))
# (batch, max_time) # (batch, max_time)
return alignment.squeeze(-1) return alignment.squeeze(-1)
@ -120,7 +126,7 @@ class AttentionRNNCell(nn.Module):
'b' (Bahdanau) or 'ls' (Location Sensitive).".format( 'b' (Bahdanau) or 'ls' (Location Sensitive).".format(
align_model)) align_model))
def forward(self, memory, context, rnn_state, annots, atten, mask): def forward(self, memory, context, rnn_state, annots, atten, mask, t):
""" """
Shapes: Shapes:
- memory: (batch, 1, dim) or (batch, dim) - memory: (batch, 1, dim) or (batch, dim)
@ -130,6 +136,8 @@ class AttentionRNNCell(nn.Module):
- atten: (batch, 2, max_time) - atten: (batch, 2, max_time)
- mask: (batch,) - mask: (batch,)
""" """
if t == 0:
self.alignment_model.reset()
# Concat input query and previous context context # Concat input query and previous context context
rnn_input = torch.cat((memory, context), -1) rnn_input = torch.cat((memory, context), -1)
# Feed it to RNN # Feed it to RNN
@ -147,6 +155,7 @@ class AttentionRNNCell(nn.Module):
alignment.masked_fill_(1 - mask, -float("inf")) alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize context weight # Normalize context weight
# alignment = F.softmax(alignment, dim=-1) # alignment = F.softmax(alignment, dim=-1)
# alignment = 5 * alignment
alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1) alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
# Attention context vector # Attention context vector
# (batch, 1, dim) # (batch, 1, dim)

View File

@ -402,7 +402,7 @@ class Decoder(nn.Module):
(attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1) (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn( attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention_cat, mask) inputs, attention_cat, mask, t)
attention_cum += attention attention_cum += attention
# Concat RNN output and attention context vector # Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in( decoder_input = self.project_to_decoder_in(