mirror of https://github.com/coqui-ai/TTS.git
forward attention for TF Tacotron2
This commit is contained in:
parent
52473d4853
commit
dfd5e3cbfc
|
@ -155,6 +155,23 @@ class Attention(keras.layers.Layer):
|
|||
score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
|
||||
return score
|
||||
|
||||
def apply_forward_attention(self, alignment, alpha):
|
||||
# forward attention
|
||||
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)))
|
||||
# compute transition potentials
|
||||
new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
|
||||
# renormalize attention weights
|
||||
new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
|
||||
return new_alpha
|
||||
|
||||
def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
|
||||
states = []
|
||||
if self.use_loc_attn:
|
||||
states = [old_states[0] + scores_norm, attn_weights]
|
||||
if self.use_forward_attn:
|
||||
states.append(new_alpha)
|
||||
return tuple(states)
|
||||
|
||||
def call(self, query, states):
|
||||
"""
|
||||
shapes:
|
||||
|
@ -170,13 +187,19 @@ class Attention(keras.layers.Layer):
|
|||
# self.apply_score_masking(score, mask)
|
||||
# attn_weights shape == (batch_size, max_length, 1)
|
||||
|
||||
attn_weights = self.norm_func(score)
|
||||
# normalize attention scores
|
||||
scores_norm = self.norm_func(score)
|
||||
attn_weights = scores_norm
|
||||
|
||||
# update attention states
|
||||
if self.use_loc_attn:
|
||||
states = (states[0] + attn_weights, attn_weights)
|
||||
else:
|
||||
states = ()
|
||||
# apply forward attention
|
||||
new_alpha = None
|
||||
if self.use_forward_attn:
|
||||
new_alpha = self.apply_forward_attention(attn_weights, states[-1])
|
||||
attn_weights = new_alpha
|
||||
|
||||
# update states tuple
|
||||
# states = (cum_attn_weights, attn_weights, new_alpha)
|
||||
states = self.update_states(states, scores_norm, attn_weights, new_alpha)
|
||||
|
||||
# context_vector shape after sum == (batch_size, hidden_size)
|
||||
context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
|
||||
|
|
Loading…
Reference in New Issue