From 60ae915156ad842a6af70d1c5cf1f53d7681df83 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Sun, 25 Mar 2018 12:01:41 -0700
Subject: [PATCH] normal attention

---
 layers/attention.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index e7385149..1f83c169 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -57,11 +57,19 @@ class AttentionRNN(nn.Module):
         if annotations_lengths is not None and mask is None:
             mask = get_mask_from_lengths(annotations, annotations_lengths)
 
+
+        # Concat input query and previous context context
+        rnn_input = torch.cat((memory, context), -1)
+        #rnn_input = rnn_input.unsqueeze(1)
+
+        # Feed it to RNN
+        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
+        rnn_output = self.rnn_cell(rnn_input, rnn_state)
 
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_state)
+        alignment = self.alignment_model(annotations, rnn_output)
 
         # TODO: needs recheck.
         if mask is not None:
@@ -75,16 +83,6 @@ class AttentionRNN(nn.Module):
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
         context = torch.bmm(alignment.unsqueeze(1), annotations)
-        context = context.squeeze(1)
-
-        # Concat input query and previous context context
-        rnn_input = torch.cat((memory, context), -1)
-        #rnn_input = rnn_input.unsqueeze(1)
-
-        # Feed it to RNN
-        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
-        rnn_output = self.rnn_cell(rnn_input, rnn_state)
-
         context = context.squeeze(1)
 
         return rnn_output, context, alignment
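
Note (not part of the patch): the change moves the decoder RNN update ahead of the alignment step, so attention scores are computed against the updated state s_i instead of the previous state s_{i-1}. Below is a minimal, self-contained sketch of that ordering. It is not the repository's implementation: the additive scorer (query_layer, annot_layer, v), the constructor dimensions, and the omission of masking are all assumptions made for illustration; only the forward-pass ordering mirrors the diff.

    # Sketch only -- hypothetical module, simplified from layers/attention.py.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class AttentionRNNSketch(nn.Module):
        def __init__(self, out_dim, annot_dim, memory_dim):
            super().__init__()
            # s_i = f(y_{i-1}, c_{i-1}, s_{i-1}) modelled by a GRU cell.
            # Assumed input size: memory plus previous context vector.
            self.rnn_cell = nn.GRUCell(memory_dim + annot_dim, out_dim)
            # Toy additive (Bahdanau-style) scorer; names are hypothetical.
            self.query_layer = nn.Linear(out_dim, out_dim)
            self.annot_layer = nn.Linear(annot_dim, out_dim)
            self.v = nn.Linear(out_dim, 1, bias=False)

        def alignment_model(self, annotations, query):
            # e_{ij} = v^T tanh(W h_j + U s_i)  -> (batch, max_time)
            scores = self.v(torch.tanh(self.annot_layer(annotations)
                                       + self.query_layer(query).unsqueeze(1)))
            return scores.squeeze(-1)

        def forward(self, memory, context, rnn_state, annotations):
            # 1. Update the decoder state from the previous context
            #    (this is the step the patch moves to the front).
            rnn_input = torch.cat((memory, context), -1)
            rnn_output = self.rnn_cell(rnn_input, rnn_state)
            # 2. Score annotations against the updated state and normalize.
            alignment = F.softmax(self.alignment_model(annotations, rnn_output), dim=-1)
            # 3. c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
            context = torch.bmm(alignment.unsqueeze(1), annotations).squeeze(1)
            return rnn_output, context, alignment


    # Example with made-up shapes: batch=2, 5 encoder steps.
    attn = AttentionRNNSketch(out_dim=256, annot_dim=128, memory_dim=80)
    out, ctx, align = attn(torch.randn(2, 80), torch.zeros(2, 128),
                           torch.zeros(2, 256), torch.randn(2, 5, 128))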