Update attention module Possible BUG FIX

Eren Golge 2018-02-05 06:37:40 -08:00
parent 2fd37a5bad
commit b6c5771a6f
1 changed file with 11 additions and 11 deletions


@@ -60,17 +60,10 @@ class AttentionWrapper(nn.Module):
         if memory_lengths is not None and mask is None:
             mask = get_mask_from_lengths(memory, memory_lengths)
-        # Concat input query and previous context_vec context
-        import ipdb
-        ipdb.set_trace()
-        cell_input = torch.cat((query, context_vec), -1)
-        # Feed it to RNN
-        cell_output = self.rnn_cell(cell_input, cell_state)
         # Alignment
         # (batch, max_time)
-        alignment = self.alignment_model(cell_output, processed_inputs)
+        # e_{ij} = a(s_{i-1}, h_j)
+        alignment = self.alignment_model(cell_state, processed_inputs)
         if mask is not None:
             mask = mask.view(query.size(0), -1)
@@ -81,11 +74,18 @@ class AttentionWrapper(nn.Module):
         # Attention context vector
         # (batch, 1, dim)
+        # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
         context_vec = torch.bmm(alignment.unsqueeze(1), memory)
         # (batch, dim)
         context_vec = context_vec.squeeze(1)
+        # Concat input query and previous context_vec context
+        cell_input = torch.cat((query, context_vec), -1)
+        cell_input = cell_input.unsqueeze(1)
+        # Feed it to RNN
+        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
+        cell_output = self.rnn_cell(cell_input, cell_state)
         return cell_output, context_vec, alignment
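
The fix reorders the classic Bahdanau-style attention step (besides dropping the leftover ipdb debug calls): the alignment e_{ij} = a(s_{i-1}, h_j) is now scored against the previous RNN state before the cell is advanced, the context vector c_i is built from the resulting weights, and only then is the cell stepped, matching s_i = f(y_{i-1}, c_i, s_{i-1}). Below is a minimal, self-contained sketch of that corrected ordering, not the repository's actual module: the class names, the GRUCell choice, and the additive alignment model are assumptions, softmax normalization (done between the two hunks in the real file) is applied inline, and masking is omitted for brevity.

    # Minimal sketch of the fixed ordering: score the previous decoder state,
    # build the context vector, then advance the RNN. Hypothetical names;
    # mask handling from the real module is omitted.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class BahdanauAlignment(nn.Module):
        # e_{ij} = v^T tanh(W s_{i-1} + U h_j); U is assumed applied upstream,
        # so this module receives already-processed encoder outputs.
        def __init__(self, dim):
            super().__init__()
            self.query_layer = nn.Linear(dim, dim, bias=False)
            self.v = nn.Linear(dim, 1, bias=False)

        def forward(self, state, processed_inputs):
            # state: (batch, dim); processed_inputs: (batch, max_time, dim)
            keys = processed_inputs + self.query_layer(state).unsqueeze(1)
            return self.v(torch.tanh(keys)).squeeze(-1)  # (batch, max_time)

    class AttentionWrapperSketch(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.alignment_model = BahdanauAlignment(dim)
            self.rnn_cell = nn.GRUCell(2 * dim, dim)

        def forward(self, query, cell_state, memory, processed_inputs):
            # e_{ij} = a(s_{i-1}, h_j): score the *previous* state, as in the fix
            alignment = F.softmax(
                self.alignment_model(cell_state, processed_inputs), dim=-1)
            # c_i = \sum_j \alpha_{ij} h_j: (batch, 1, max_time) x (batch, max_time, dim)
            context_vec = torch.bmm(alignment.unsqueeze(1), memory).squeeze(1)
            # s_i = f(y_{i-1}, c_i, s_{i-1}): advance the cell last
            cell_input = torch.cat((query, context_vec), -1)
            cell_output = self.rnn_cell(cell_input, cell_state)
            return cell_output, context_vec, alignment

    batch, max_time, dim = 2, 7, 16
    wrapper = AttentionWrapperSketch(dim)
    memory = torch.randn(batch, max_time, dim)
    out, ctx, align = wrapper(torch.randn(batch, dim),   # query y_{i-1}
                              torch.randn(batch, dim),   # state s_{i-1}
                              memory,
                              memory)  # memory stands in for its processed form
    print(out.shape, ctx.shape, align.shape)  # (2, 16) (2, 16) (2, 7)

Note the sketch feeds a 2-D (batch, 2*dim) input to nn.GRUCell and therefore skips the unsqueeze(1) seen in the diff, whose RNN cell may expect a different shape.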