diff --git a/layers/attention.py b/layers/attention.py
index 31a3a23d..958c5701 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -63,6 +63,8 @@ class AttentionWrapper(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
+        # import ipdb
+        # ipdb.set_trace()
         alignment = self.alignment_model(cell_state, processed_inputs)

         if mask is not None:
@@ -80,12 +82,13 @@ class AttentionWrapper(nn.Module):
         # Concat input query and previous context_vec context
         cell_input = torch.cat((query, context_vec), -1)
-        cell_input = cell_input.unsqueeze(1)
+        #cell_input = cell_input.unsqueeze(1)

         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
         cell_output = self.rnn_cell(cell_input, cell_state)

+        context_vec = context_vec.squeeze(1)
         return cell_output, context_vec, alignment
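Note on the shape change: commenting out `cell_input.unsqueeze(1)` (and squeezing `context_vec` before the return) is consistent with `self.rnn_cell` being an `nn.GRUCell`/`nn.LSTMCell`-style cell that consumes a 2-D `(batch, input_size)` tensor per decoder step, rather than a batched `(batch, seq_len, input_size)` sequence. Below is a minimal sketch of the assumed shape flow; the dimensions and the choice of `nn.GRUCell` are illustrative assumptions, not taken from the patch.

    import torch
    import torch.nn as nn

    # Hypothetical sizes, for illustration only.
    batch, query_dim, context_dim, hidden_dim = 2, 256, 256, 512

    query = torch.randn(batch, query_dim)          # y_{i-1}
    context_vec = torch.randn(batch, context_dim)  # c_i
    cell_state = torch.randn(batch, hidden_dim)    # s_{i-1}

    # nn.GRUCell.forward(input, hx) expects input of shape (batch, input_size)
    # and hx of shape (batch, hidden_size), so no unsqueeze(1) is needed here.
    rnn_cell = nn.GRUCell(query_dim + context_dim, hidden_dim)

    cell_input = torch.cat((query, context_vec), -1)  # (batch, query_dim + context_dim)
    cell_output = rnn_cell(cell_input, cell_state)    # s_i: (batch, hidden_dim)
    print(cell_output.shape)                          # torch.Size([2, 512])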