diff --git a/tf/layers/common_layers.py b/tf/layers/common_layers.py
index 195acfed..73c066c1 100644
--- a/tf/layers/common_layers.py
+++ b/tf/layers/common_layers.py
@@ -155,6 +155,23 @@ class Attention(keras.layers.Layer):
         score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
         return score
 
+    def apply_forward_attention(self, alignment, alpha):
+        # forward attention
+        fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)))
+        # compute transition potentials
+        new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
+        # renormalize attention weights
+        new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
+        return new_alpha
+
+    def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
+        states = []
+        if self.use_loc_attn:
+            states = [old_states[0] + scores_norm, attn_weights]
+        if self.use_forward_attn:
+            states.append(new_alpha)
+        return tuple(states)
+
     def call(self, query, states):
         """
         shapes:
@@ -170,13 +187,19 @@ class Attention(keras.layers.Layer):
         #     self.apply_score_masking(score, mask)
 
         # attn_weights shape == (batch_size, max_length, 1)
-        attn_weights = self.norm_func(score)
+        # normalize attention scores
+        scores_norm = self.norm_func(score)
+        attn_weights = scores_norm
 
-        # update attention states
-        if self.use_loc_attn:
-            states = (states[0] + attn_weights, attn_weights)
-        else:
-            states = ()
+        # apply forward attention
+        new_alpha = None
+        if self.use_forward_attn:
+            new_alpha = self.apply_forward_attention(attn_weights, states[-1])
+            attn_weights = new_alpha
+
+        # update states tuple
+        # states = (cum_attn_weights, attn_weights, new_alpha)
+        states = self.update_states(states, scores_norm, attn_weights, new_alpha)
 
         # context_vector shape after sum == (batch_size, hidden_size)
         context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
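
For reference, the forward-attention update added in `apply_forward_attention` can be exercised on its own. The sketch below is not part of the PR; it assumes TensorFlow 2.x eager execution, and the function name `forward_attention_step`, the toy tensors, and the `transition`/`eps` parameters are illustrative stand-ins for the hard-coded `0.5` and `1e-8` in the diff.

```python
# Minimal standalone sketch of the forward-attention step (illustrative, not from the PR).
import tensorflow as tf

def forward_attention_step(alignment, alpha, transition=0.5, eps=1e-8):
    # alignment, alpha: (batch_size, max_length)
    # shift the previous alpha one encoder step forward, zero-padding position 0
    fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)))
    # blend the "stay" and "move one step forward" paths, then gate by the new alignment
    new_alpha = ((1 - transition) * alpha + transition * fwd_shifted_alpha + eps) * alignment
    # renormalize so the weights sum to one over encoder steps
    return new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)

# toy example: mass that starts on encoder step 0 can only stay there or move to step 1
alpha0 = tf.constant([[1.0, 0.0, 0.0, 0.0]])
alignment = tf.constant([[0.1, 0.6, 0.2, 0.1]])
print(forward_attention_step(alignment, alpha0).numpy())
# ~[[0.143, 0.857, ~0, ~0]] — later encoder steps stay near zero, enforcing monotonic progress
```

This mirrors the intent of the diff: at each decoder step the attention can either hold its position or advance by one encoder step, which is why `update_states` appends `new_alpha` as an extra recurrent state when `use_forward_attn` is enabled.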