diff --git a/layers/attention.py b/layers/attention.py index e0d5e52c..6b9ee47b 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -105,7 +105,7 @@ class AttentionRNN(nn.Module): # Alignment # (batch, max_time) # e_{ij} = a(s_{i-1}, h_j) - if attnetion_vec is None: + if self.align_model is 'b': alignment = self.alignment_model(annotations, rnn_output) else: alignment = self.alignment_model(annotations, rnn_output, attention_vec)