forward attention for TF Tacotron2

erogol 2020-07-08 10:23:28 +02:00
parent 52473d4853
commit dfd5e3cbfc
1 changed file with 29 additions and 6 deletions


@@ -155,6 +155,23 @@ class Attention(keras.layers.Layer):
        score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
        return score

    def apply_forward_attention(self, alignment, alpha):
        # forward attention
        fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)))
        # compute transition potentials
        new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
        # renormalize attention weights
        new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
        return new_alpha

    def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
        states = []
        if self.use_loc_attn:
            states = [old_states[0] + scores_norm, attn_weights]
        if self.use_forward_attn:
            states.append(new_alpha)
        return tuple(states)

    def call(self, query, states):
        """
        shapes:
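The two methods added in this hunk implement the forward attention update: the previous step's weights `alpha` are shifted one encoder position to the right, blended with the unshifted weights using a fixed 0.5 transition weight, gated by the new alignment, and renormalized. Below is a minimal standalone sketch of that computation on a toy input, assuming TensorFlow 2.x; the `forward_attention_step` name, the `transition_weight`/`eps` parameters, and the toy tensors are illustrative and not part of the commit.

import tensorflow as tf

def forward_attention_step(alignment, alpha, transition_weight=0.5, eps=1e-8):
    # previous-step weights shifted one encoder position to the right
    fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)))
    # blend "stay" and "move one step forward", then gate by the new alignment
    new_alpha = ((1 - transition_weight) * alpha
                 + transition_weight * fwd_shifted_alpha + eps) * alignment
    # renormalize so the weights sum to 1 over encoder positions
    return new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)

# batch of 1, encoder length 4: alpha starts fully on the first position
alpha = tf.constant([[1.0, 0.0, 0.0, 0.0]])
alignment = tf.constant([[0.1, 0.7, 0.1, 0.1]])
print(forward_attention_step(alignment, alpha).numpy())
# mass can only stay at position 0 or advance to position 1: ~[0.125, 0.875, 0, 0]

With this update, attention mass at each decoder step can only stay in place or advance by one encoder position, which is what keeps the alignment monotonic.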
@@ -170,13 +187,19 @@ class Attention(keras.layers.Layer):
        # self.apply_score_masking(score, mask)
        # attn_weights shape == (batch_size, max_length, 1)
        attn_weights = self.norm_func(score)
        # normalize attention scores
        scores_norm = self.norm_func(score)
        attn_weights = scores_norm
        # update attention states
        if self.use_loc_attn:
            states = (states[0] + attn_weights, attn_weights)
        else:
            states = ()
        # apply forward attention
        new_alpha = None
        if self.use_forward_attn:
            new_alpha = self.apply_forward_attention(attn_weights, states[-1])
            attn_weights = new_alpha
        # update states tuple
        # states = (cum_attn_weights, attn_weights, new_alpha)
        states = self.update_states(states, scores_norm, attn_weights, new_alpha)
        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
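With both `use_loc_attn` and `use_forward_attn` enabled, `update_states` returns a three-element tuple and `call` reads the previous step's `alpha` as `states[-1]`. The sketch below shows one way such a states tuple could be initialized and threaded through decoding, assuming the layout `(cum_attn_weights, attn_weights, alpha)`; the `init_attn_states` helper, the one-hot initialization of `alpha`, and the decoder-loop identifiers are assumptions for illustration, not taken from this commit.

import tensorflow as tf

def init_attn_states(batch_size, enc_len):
    # cumulative weights and previous-step weights start at zero
    cum_attn_weights = tf.zeros([batch_size, enc_len])
    attn_weights = tf.zeros([batch_size, enc_len])
    # forward attention starts with all mass on the first encoder position,
    # so the alignment can only move forward from there
    alpha = tf.concat([tf.ones([batch_size, 1]),
                       tf.zeros([batch_size, enc_len - 1])], axis=1)
    return (cum_attn_weights, attn_weights, alpha)

# usage sketch inside a decoder loop (the return signature of `call` is not
# shown in this hunk, so the unpacking below is an assumption):
# states = init_attn_states(batch_size, enc_len)
# for _ in range(max_decoder_steps):
#     context_vector, attn_weights, states = attention(query, states)
#     query = decoder_rnn(tf.concat([prenet_out, context_vector], axis=-1))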