mirror of https://github.com/coqui-ai/TTS.git
forward attention for TF Tacotron2
This commit is contained in: parent 52473d4853, commit dfd5e3cbfc
@@ -155,6 +155,23 @@ class Attention(keras.layers.Layer):
         score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
         return score

+    def apply_forward_attention(self, alignment, alpha):
+        # forward attention
+        fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)))
+        # compute transition potentials
+        new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
+        # renormalize attention weights
+        new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
+        return new_alpha
+
+    def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
+        states = []
+        if self.use_loc_attn:
+            states = [old_states[0] + scores_norm, attn_weights]
+        if self.use_forward_attn:
+            states.append(new_alpha)
+        return tuple(states)
+
     def call(self, query, states):
         """
         shapes:
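The new apply_forward_attention() is the forward-attention recurrence of Zhang et al. (2018), "Forward Attention in Sequence-to-Sequence Acoustic Modelling for Speech Synthesis", without a transition agent: alpha_t(n) is proportional to ((1 - u) * alpha_{t-1}(n) + u * alpha_{t-1}(n-1)) * y_t(n), with the transition weight u hardcoded to 0.5 and a 1e-8 floor added before renormalization. A minimal NumPy re-statement of the same step, with a toy check (the function name and example values are illustrative, not from the commit):

import numpy as np

def forward_attention_step(alignment, alpha):
    # alignment: attention weights at the current decoder step, (batch, max_length)
    # alpha:     forward-attention weights from the previous step, same shape
    # shift alpha one encoder step to the right, so position n inherits
    # probability mass from position n-1
    fwd_shifted_alpha = np.pad(alpha[:, :-1], ((0, 0), (1, 0)))
    # stay in place (weight 1 - 0.5) or advance one step (weight 0.5);
    # the 1e-8 floor keeps fully zeroed positions renormalizable
    new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
    # renormalize to a probability distribution over encoder steps
    return new_alpha / new_alpha.sum(axis=1, keepdims=True)

# toy check: previous alpha sits entirely on step 0 and the raw alignment
# prefers step 2, but forward attention only lets mass reach steps 0 and 1
alpha = np.array([[1.0, 0.0, 0.0, 0.0]])
alignment = np.array([[0.1, 0.2, 0.6, 0.1]])
print(forward_attention_step(alignment, alpha).round(3))
# -> [[0.333 0.667 0.    0.   ]]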
@@ -170,13 +187,19 @@ class Attention(keras.layers.Layer):
         # self.apply_score_masking(score, mask)
         # attn_weights shape == (batch_size, max_length, 1)

-        attn_weights = self.norm_func(score)
+        # normalize attention scores
+        scores_norm = self.norm_func(score)
+        attn_weights = scores_norm

-        # update attention states
-        if self.use_loc_attn:
-            states = (states[0] + attn_weights, attn_weights)
-        else:
-            states = ()
+        # apply forward attention
+        new_alpha = None
+        if self.use_forward_attn:
+            new_alpha = self.apply_forward_attention(attn_weights, states[-1])
+            attn_weights = new_alpha
+
+        # update states tuple
+        # states = (cum_attn_weights, attn_weights, new_alpha)
+        states = self.update_states(states, scores_norm, attn_weights, new_alpha)

         # context_vector shape after sum == (batch_size, hidden_size)
         context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
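A note on the shapes in the final matmul: expand_dims turns attn_weights of shape (batch_size, max_length) into (batch_size, max_length, 1); with transpose_a=True the batched matmul is (batch_size, 1, max_length) times (batch_size, max_length, hidden_size), giving (batch_size, 1, hidden_size), the attention-weighted sum over encoder steps that the comment refers to (presumably squeezed to (batch_size, hidden_size) downstream). A quick stand-in shape check, with illustrative sizes not taken from the commit:

import tensorflow as tf

batch, max_length, hidden = 2, 7, 16          # illustrative sizes
attn_weights = tf.random.uniform((batch, max_length))
values = tf.random.normal((batch, max_length, hidden))

context = tf.matmul(tf.expand_dims(attn_weights, axis=2), values,
                    transpose_a=True, transpose_b=False)
print(context.shape)                          # (2, 1, 16)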
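One detail of update_states() worth keeping in mind: the layout of the states tuple depends on the flags, but new_alpha, when present, is always the last element, which is what call() relies on when it reads the previous alpha via states[-1]. A hypothetical initializer (the commit itself shows none) illustrating the possible layouts:

import numpy as np

def init_attn_states(batch, max_length, use_loc_attn, use_forward_attn):
    # hypothetical sketch: the commit does not show state initialization
    states = []
    if use_loc_attn:
        states += [np.zeros((batch, max_length)),   # cumulative attention weights
                   np.zeros((batch, max_length))]   # previous attention weights
    if use_forward_attn:
        alpha = np.zeros((batch, max_length))
        alpha[:, 0] = 1.0                           # start attending at encoder step 0
        states.append(alpha)                        # alpha is always states[-1]
    return tuple(states)

# resulting layouts:
#   loc_attn only:            (cum_attn_weights, attn_weights)
#   loc_attn + forward_attn:  (cum_attn_weights, attn_weights, alpha)
#   forward_attn only:        (alpha,)
print(len(init_attn_states(2, 7, True, True)))   # 3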