A major bug fix for location sensitive attention.

This commit is contained in:
Eren Golge 2018-05-23 06:04:28 -07:00
parent 6bcec24d13
commit 7acf4eab94
2 changed files with 22 additions and 12 deletions

View File

@ -69,10 +69,18 @@ class LocationSensitiveAttention(nn.Module):
class AttentionRNN(nn.Module):
def __init__(self, out_dim, annot_dim, memory_dim):
def __init__(self, out_dim, annot_dim, memory_dim, align_model):
super(AttentionRNN, self).__init__()
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
# pick bahdanau or location sensitive attention
if align_model == 'b':
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
if align_model == 'ls':
self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
else:
raise RuntimeError(" Wrong alignment model name: {}. Use\
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
def forward(self, memory, context, rnn_state, annotations,
attention_vec, mask=None, annotations_lengths=None):
@ -88,7 +96,10 @@ class AttentionRNN(nn.Module):
# Alignment
# (batch, max_time)
# e_{ij} = a(s_{i-1}, h_j)
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
if attnetion_vec is None:
alignment = self.alignment_model(annotations, rnn_output)
else:
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
# TODO: needs recheck.
if mask is not None:

View File

@ -212,12 +212,10 @@ class Decoder(nn.Module):
eps (float): threshold for detecting the end of a sentence.
"""
def __init__(self, in_features, memory_dim, r, eps=0, mode='train'):
def __init__(self, in_features, memory_dim, r):
super(Decoder, self).__init__()
self.mode = mode
self.max_decoder_steps = 200
self.memory_dim = memory_dim
self.eps = eps
self.r = r
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
@ -268,8 +266,8 @@ class Decoder(nn.Module):
for _ in range(len(self.decoder_rnns))]
current_context_vec = inputs.data.new(B, 256).zero_()
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
attention_vec = inputs.data.new(B, T).zero_()
attention_vec_cum = inputs.data.new(B, T).zero_()
attention = inputs.data.new(B, T).zero_()
attention_cum = inputs.data.new(B, T).zero_()
# Time first (T_decoder, B, memory_dim)
if memory is not None:
memory = memory.transpose(0, 1)
@ -287,12 +285,12 @@ class Decoder(nn.Module):
# Prenet
processed_memory = self.prenet(memory_input)
# Attention RNN
attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
attention_vec_cum.unsqueeze(1) / (t + 1)),
attention_cat = torch.cat((attention.unsqueeze(1),
attention_cum.unsqueeze(1) / (t + 1)),
dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
attention_vec_cum += attention_vec
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_cat)
attention_cum += attention
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((attention_rnn_hidden, current_context_vec), -1))
@ -333,6 +331,7 @@ class Decoder(nn.Module):
class StopNet(nn.Module):
def __init__(self, r, memory_dim):
"""Predicts the stop token to stop the decoder at testing time"""
super(StopNet, self).__init__()
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
self.relu = nn.ReLU()