mirror of https://github.com/coqui-ai/TTS.git
continue refactoring
This commit is contained in:
parent
8f9858cf44
commit
4396f8e2da
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.common_layers import Linear
|
||||
from TTS.tts.layers.tacotron.common_layers import Linear
|
||||
from scipy.stats import betabinom
|
||||
|
||||
|
|
@ -1,288 +0,0 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.python.ops import math_ops
|
||||
# from tensorflow_addons.seq2seq import BahdanauAttention
|
||||
|
||||
# NOTE: linter has a problem with the current TF release
|
||||
#pylint: disable=no-value-for-parameter
|
||||
#pylint: disable=unexpected-keyword-arg
|
||||
|
||||
class Linear(keras.layers.Layer):
|
||||
def __init__(self, units, use_bias, **kwargs):
|
||||
super(Linear, self).__init__(**kwargs)
|
||||
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
|
||||
self.activation = keras.layers.ReLU()
|
||||
|
||||
def call(self, x):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
return self.activation(self.linear_layer(x))
|
||||
|
||||
|
||||
class LinearBN(keras.layers.Layer):
|
||||
def __init__(self, units, use_bias, **kwargs):
|
||||
super(LinearBN, self).__init__(**kwargs)
|
||||
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
|
||||
self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization')
|
||||
self.activation = keras.layers.ReLU()
|
||||
|
||||
def call(self, x, training=None):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
out = self.linear_layer(x)
|
||||
out = self.batch_normalization(out, training=training)
|
||||
return self.activation(out)
|
||||
|
||||
|
||||
class Prenet(keras.layers.Layer):
|
||||
def __init__(self,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
units,
|
||||
bias,
|
||||
**kwargs):
|
||||
super(Prenet, self).__init__(**kwargs)
|
||||
self.prenet_type = prenet_type
|
||||
self.prenet_dropout = prenet_dropout
|
||||
self.linear_layers = []
|
||||
if prenet_type == "bn":
|
||||
self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
|
||||
elif prenet_type == "original":
|
||||
self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
|
||||
else:
|
||||
raise RuntimeError(' [!] Unknown prenet type.')
|
||||
if prenet_dropout:
|
||||
self.dropout = keras.layers.Dropout(rate=0.5)
|
||||
|
||||
def call(self, x, training=None):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
for linear in self.linear_layers:
|
||||
if self.prenet_dropout:
|
||||
x = self.dropout(linear(x), training=training)
|
||||
else:
|
||||
x = linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def _sigmoid_norm(score):
|
||||
attn_weights = tf.nn.sigmoid(score)
|
||||
attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
|
||||
return attn_weights
|
||||
|
||||
|
||||
class Attention(keras.layers.Layer):
|
||||
"""TODO: implement forward_attention
|
||||
TODO: location sensitive attention
|
||||
TODO: implement attention windowing """
|
||||
def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
|
||||
loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
|
||||
use_trans_agent, use_forward_attn_mask, **kwargs):
|
||||
super(Attention, self).__init__(**kwargs)
|
||||
self.use_loc_attn = use_loc_attn
|
||||
self.loc_attn_n_filters = loc_attn_n_filters
|
||||
self.loc_attn_kernel_size = loc_attn_kernel_size
|
||||
self.use_windowing = use_windowing
|
||||
self.norm = norm
|
||||
self.use_forward_attn = use_forward_attn
|
||||
self.use_trans_agent = use_trans_agent
|
||||
self.use_forward_attn_mask = use_forward_attn_mask
|
||||
self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
|
||||
self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
|
||||
self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
|
||||
if use_loc_attn:
|
||||
self.location_conv1d = keras.layers.Conv1D(
|
||||
filters=loc_attn_n_filters,
|
||||
kernel_size=loc_attn_kernel_size,
|
||||
padding='same',
|
||||
use_bias=False,
|
||||
name='location_layer/location_conv1d')
|
||||
self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
|
||||
if norm == 'softmax':
|
||||
self.norm_func = tf.nn.softmax
|
||||
elif norm == 'sigmoid':
|
||||
self.norm_func = _sigmoid_norm
|
||||
else:
|
||||
raise ValueError("Unknown value for attention norm type")
|
||||
|
||||
def init_states(self, batch_size, value_length):
|
||||
states = []
|
||||
if self.use_loc_attn:
|
||||
attention_cum = tf.zeros([batch_size, value_length])
|
||||
attention_old = tf.zeros([batch_size, value_length])
|
||||
states = [attention_cum, attention_old]
|
||||
if self.use_forward_attn:
|
||||
alpha = tf.concat([
|
||||
tf.ones([batch_size, 1]),
|
||||
tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
|
||||
], 1)
|
||||
states.append(alpha)
|
||||
return tuple(states)
|
||||
|
||||
def process_values(self, values):
|
||||
""" cache values for decoder iterations """
|
||||
#pylint: disable=attribute-defined-outside-init
|
||||
self.processed_values = self.inputs_layer(values)
|
||||
self.values = values
|
||||
|
||||
def get_loc_attn(self, query, states):
|
||||
""" compute location attention, query layer and
|
||||
unnorm. attention weights"""
|
||||
attention_cum, attention_old = states[:2]
|
||||
attn_cat = tf.stack([attention_old, attention_cum], axis=2)
|
||||
|
||||
processed_query = self.query_layer(tf.expand_dims(query, 1))
|
||||
processed_attn = self.location_dense(self.location_conv1d(attn_cat))
|
||||
score = self.v(
|
||||
tf.nn.tanh(self.processed_values + processed_query +
|
||||
processed_attn))
|
||||
score = tf.squeeze(score, axis=2)
|
||||
return score, processed_query
|
||||
|
||||
def get_attn(self, query):
|
||||
""" compute query layer and unnormalized attention weights """
|
||||
processed_query = self.query_layer(tf.expand_dims(query, 1))
|
||||
score = self.v(tf.nn.tanh(self.processed_values + processed_query))
|
||||
score = tf.squeeze(score, axis=2)
|
||||
return score, processed_query
|
||||
|
||||
def apply_score_masking(self, score, mask): #pylint: disable=no-self-use
|
||||
""" ignore sequence paddings """
|
||||
padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
|
||||
# Bias so padding positions do not contribute to attention distribution.
|
||||
score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
|
||||
return score
|
||||
|
||||
def apply_forward_attention(self, alignment, alpha): #pylint: disable=no-self-use
|
||||
# forward attention
|
||||
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
|
||||
# compute transition potentials
|
||||
new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
|
||||
# renormalize attention weights
|
||||
new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
|
||||
return new_alpha
|
||||
|
||||
def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
|
||||
states = []
|
||||
if self.use_loc_attn:
|
||||
states = [old_states[0] + scores_norm, attn_weights]
|
||||
if self.use_forward_attn:
|
||||
states.append(new_alpha)
|
||||
return tuple(states)
|
||||
|
||||
def call(self, query, states):
|
||||
"""
|
||||
shapes:
|
||||
query: B x D
|
||||
"""
|
||||
if self.use_loc_attn:
|
||||
score, _ = self.get_loc_attn(query, states)
|
||||
else:
|
||||
score, _ = self.get_attn(query)
|
||||
|
||||
# TODO: masking
|
||||
# if mask is not None:
|
||||
# self.apply_score_masking(score, mask)
|
||||
# attn_weights shape == (batch_size, max_length, 1)
|
||||
|
||||
# normalize attention scores
|
||||
scores_norm = self.norm_func(score)
|
||||
attn_weights = scores_norm
|
||||
|
||||
# apply forward attention
|
||||
new_alpha = None
|
||||
if self.use_forward_attn:
|
||||
new_alpha = self.apply_forward_attention(attn_weights, states[-1])
|
||||
attn_weights = new_alpha
|
||||
|
||||
# update states tuple
|
||||
# states = (cum_attn_weights, attn_weights, new_alpha)
|
||||
states = self.update_states(states, scores_norm, attn_weights, new_alpha)
|
||||
|
||||
# context_vector shape after sum == (batch_size, hidden_size)
|
||||
context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
|
||||
context_vector = tf.squeeze(context_vector, axis=1)
|
||||
return context_vector, attn_weights, states
|
||||
|
||||
|
||||
# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
|
||||
# dtype = processed_query.dtype
|
||||
# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
|
||||
# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])
|
||||
|
||||
|
||||
# class LocationSensitiveAttention(BahdanauAttention):
|
||||
# def __init__(self,
|
||||
# units,
|
||||
# memory=None,
|
||||
# memory_sequence_length=None,
|
||||
# normalize=False,
|
||||
# probability_fn="softmax",
|
||||
# kernel_initializer="glorot_uniform",
|
||||
# dtype=None,
|
||||
# name="LocationSensitiveAttention",
|
||||
# location_attention_filters=32,
|
||||
# location_attention_kernel_size=31):
|
||||
|
||||
# super(LocationSensitiveAttention,
|
||||
# self).__init__(units=units,
|
||||
# memory=memory,
|
||||
# memory_sequence_length=memory_sequence_length,
|
||||
# normalize=normalize,
|
||||
# probability_fn='softmax', ## parent module default
|
||||
# kernel_initializer=kernel_initializer,
|
||||
# dtype=dtype,
|
||||
# name=name)
|
||||
# if probability_fn == 'sigmoid':
|
||||
# self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
|
||||
# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
|
||||
# self.location_dense = keras.layers.Dense(units, use_bias=False)
|
||||
# # self.v = keras.layers.Dense(1, use_bias=True)
|
||||
|
||||
# def _location_sensitive_score(self, processed_query, keys, processed_loc):
|
||||
# processed_query = tf.expand_dims(processed_query, 1)
|
||||
# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])
|
||||
|
||||
# def _location_sensitive(self, alignment_cum, alignment_old):
|
||||
# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
|
||||
# return self.location_dense(self.location_conv(alignment_cat))
|
||||
|
||||
# def _sigmoid_normalization(self, score):
|
||||
# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)
|
||||
|
||||
# # def _apply_masking(self, score, mask):
|
||||
# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
|
||||
# # # Bias so padding positions do not contribute to attention distribution.
|
||||
# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
|
||||
# # return score
|
||||
|
||||
# def _calculate_attention(self, query, state):
|
||||
# alignment_cum, alignment_old = state[:2]
|
||||
# processed_query = self.query_layer(
|
||||
# query) if self.query_layer else query
|
||||
# processed_loc = self._location_sensitive(alignment_cum, alignment_old)
|
||||
# score = self._location_sensitive_score(
|
||||
# processed_query,
|
||||
# self.keys,
|
||||
# processed_loc)
|
||||
# alignment = self.probability_fn(score, state)
|
||||
# alignment_cum = alignment_cum + alignment
|
||||
# state[0] = alignment_cum
|
||||
# state[1] = alignment
|
||||
# return alignment, state
|
||||
|
||||
# def compute_context(self, alignments):
|
||||
# expanded_alignments = tf.expand_dims(alignments, 1)
|
||||
# context = tf.matmul(expanded_alignments, self.values)
|
||||
# context = tf.squeeze(context, [1])
|
||||
# return context
|
||||
|
||||
# # def call(self, query, state):
|
||||
# # alignment, next_state = self._calculate_attention(query, state)
|
||||
# # return alignment, next_state
|
|
@ -1,302 +0,0 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from TTS.tts.tf.utils.tf_utils import shape_list
|
||||
from TTS.tts.tf.layers.common_layers import Prenet, Attention
|
||||
|
||||
|
||||
# NOTE: linter has a problem with the current TF release
|
||||
#pylint: disable=no-value-for-parameter
|
||||
#pylint: disable=unexpected-keyword-arg
|
||||
class ConvBNBlock(keras.layers.Layer):
|
||||
def __init__(self, filters, kernel_size, activation, **kwargs):
|
||||
super(ConvBNBlock, self).__init__(**kwargs)
|
||||
self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d')
|
||||
self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization')
|
||||
self.dropout = keras.layers.Dropout(rate=0.5, name='dropout')
|
||||
self.activation = keras.layers.Activation(activation, name='activation')
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = self.convolution1d(x)
|
||||
o = self.batch_normalization(o, training=training)
|
||||
o = self.activation(o)
|
||||
o = self.dropout(o, training=training)
|
||||
return o
|
||||
|
||||
|
||||
class Postnet(keras.layers.Layer):
|
||||
def __init__(self, output_filters, num_convs, **kwargs):
|
||||
super(Postnet, self).__init__(**kwargs)
|
||||
self.convolutions = []
|
||||
self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0'))
|
||||
for idx in range(1, num_convs - 1):
|
||||
self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}'))
|
||||
self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}'))
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o, training=training)
|
||||
return o
|
||||
|
||||
|
||||
class Encoder(keras.layers.Layer):
|
||||
def __init__(self, output_input_dim, **kwargs):
|
||||
super(Encoder, self).__init__(**kwargs)
|
||||
self.convolutions = []
|
||||
for idx in range(3):
|
||||
self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}'))
|
||||
self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm')
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o, training=training)
|
||||
o = self.lstm(o)
|
||||
return o
|
||||
|
||||
|
||||
class Decoder(keras.layers.Layer):
|
||||
#pylint: disable=unused-argument
|
||||
def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
|
||||
prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
|
||||
use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs):
|
||||
super(Decoder, self).__init__(**kwargs)
|
||||
self.frame_dim = frame_dim
|
||||
self.r_init = tf.constant(r, dtype=tf.int32)
|
||||
self.r = tf.constant(r, dtype=tf.int32)
|
||||
self.output_dim = r * self.frame_dim
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.enable_tflite = enable_tflite
|
||||
|
||||
# layer constants
|
||||
self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
|
||||
self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
|
||||
|
||||
# model dimensions
|
||||
self.query_dim = 1024
|
||||
self.decoder_rnn_dim = 1024
|
||||
self.prenet_dim = 256
|
||||
self.attn_dim = 128
|
||||
self.p_attention_dropout = 0.1
|
||||
self.p_decoder_dropout = 0.1
|
||||
|
||||
self.prenet = Prenet(prenet_type,
|
||||
prenet_dropout,
|
||||
[self.prenet_dim, self.prenet_dim],
|
||||
bias=False,
|
||||
name='prenet')
|
||||
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', )
|
||||
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
|
||||
|
||||
# TODO: implement other attn options
|
||||
self.attention = Attention(attn_dim=self.attn_dim,
|
||||
use_loc_attn=True,
|
||||
loc_attn_n_filters=32,
|
||||
loc_attn_kernel_size=31,
|
||||
use_windowing=False,
|
||||
norm=attn_norm,
|
||||
use_forward_attn=use_forward_attn,
|
||||
use_trans_agent=use_trans_agent,
|
||||
use_forward_attn_mask=use_forward_attn_mask,
|
||||
name='attention')
|
||||
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn')
|
||||
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
|
||||
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer')
|
||||
self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer')
|
||||
|
||||
|
||||
def set_max_decoder_steps(self, new_max_steps):
|
||||
self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)
|
||||
|
||||
def set_r(self, new_r):
|
||||
self.r = tf.constant(new_r, dtype=tf.int32)
|
||||
self.output_dim = self.frame_dim * new_r
|
||||
|
||||
def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
|
||||
zero_frame = tf.zeros([batch_size, self.frame_dim])
|
||||
zero_context = tf.zeros([batch_size, memory_dim])
|
||||
attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
|
||||
decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
|
||||
attention_states = self.attention.init_states(batch_size, memory_length)
|
||||
return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states
|
||||
|
||||
def step(self, prenet_next, states,
|
||||
memory_seq_length=None, training=None):
|
||||
_, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
|
||||
attention_rnn_input = tf.concat([prenet_next, context_next], -1)
|
||||
attention_rnn_output, attention_rnn_state = \
|
||||
self.attention_rnn(attention_rnn_input,
|
||||
attention_rnn_state, training=training)
|
||||
attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
|
||||
context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
|
||||
decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
|
||||
decoder_rnn_output, decoder_rnn_state = \
|
||||
self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training)
|
||||
decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
|
||||
linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
|
||||
output_frame = self.linear_projection(linear_projection_input, training=training)
|
||||
stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
|
||||
stopnet_output = self.stopnet(stopnet_input, training=training)
|
||||
output_frame = output_frame[:, :self.r * self.frame_dim]
|
||||
states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states)
|
||||
return output_frame, stopnet_output, states, attention
|
||||
|
||||
def decode(self, memory, states, frames, memory_seq_length=None):
|
||||
B, _, _ = shape_list(memory)
|
||||
num_iter = shape_list(frames)[1] // self.r
|
||||
# init states
|
||||
frame_zero = tf.expand_dims(states[0], 1)
|
||||
frames = tf.concat([frame_zero, frames], axis=1)
|
||||
outputs = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
attentions = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
prenet_output = self.prenet(frames, training=True)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
|
||||
prenet_next = prenet_output[:, step]
|
||||
output, stop_token, states, attention = self.step(prenet_next,
|
||||
states,
|
||||
memory_seq_length)
|
||||
outputs = outputs.write(step, output)
|
||||
attentions = attentions.write(step, attention)
|
||||
stop_tokens = stop_tokens.write(step, stop_token)
|
||||
return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
|
||||
_, memory, _, states, outputs, stop_tokens, attentions = \
|
||||
tf.while_loop(lambda *arg: True,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, prenet_output,
|
||||
states, outputs, stop_tokens, attentions),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=num_iter)
|
||||
|
||||
outputs = outputs.stack()
|
||||
attentions = attentions.stack()
|
||||
stop_tokens = stop_tokens.stack()
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
attentions = tf.transpose(attentions, [1, 0, 2])
|
||||
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
|
||||
stop_tokens = tf.squeeze(stop_tokens, axis=2)
|
||||
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
def decode_inference(self, memory, states):
|
||||
B, _, _ = shape_list(memory)
|
||||
# init states
|
||||
outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
|
||||
# iter vars
|
||||
stop_flag = tf.constant(False, dtype=tf.bool)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
|
||||
frame_next = states[0]
|
||||
prenet_next = self.prenet(frame_next, training=False)
|
||||
output, stop_token, states, attention = self.step(prenet_next,
|
||||
states,
|
||||
None,
|
||||
training=False)
|
||||
stop_token = tf.math.sigmoid(stop_token)
|
||||
outputs = outputs.write(step, output)
|
||||
attentions = attentions.write(step, attention)
|
||||
stop_tokens = stop_tokens.write(step, stop_token)
|
||||
stop_flag = tf.greater(stop_token, self.stop_thresh)
|
||||
stop_flag = tf.reduce_all(stop_flag)
|
||||
return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag
|
||||
|
||||
cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
|
||||
_, memory, states, outputs, stop_tokens, attentions, stop_flag = \
|
||||
tf.while_loop(cond,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, states, outputs,
|
||||
stop_tokens, attentions, stop_flag),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=self.max_decoder_steps)
|
||||
|
||||
outputs = outputs.stack()
|
||||
attentions = attentions.stack()
|
||||
stop_tokens = stop_tokens.stack()
|
||||
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
attentions = tf.transpose(attentions, [1, 0, 2])
|
||||
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
|
||||
stop_tokens = tf.squeeze(stop_tokens, axis=2)
|
||||
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
def decode_inference_tflite(self, memory, states):
|
||||
"""Inference with TF-Lite compatibility. It assumes
|
||||
batch_size is 1"""
|
||||
# init states
|
||||
# dynamic_shape is not supported in TFLite
|
||||
outputs = tf.TensorArray(dtype=tf.float32,
|
||||
size=self.max_decoder_steps,
|
||||
element_shape=tf.TensorShape(
|
||||
[self.output_dim]),
|
||||
clear_after_read=False,
|
||||
dynamic_size=False)
|
||||
# stop_flags = tf.TensorArray(dtype=tf.bool,
|
||||
# size=self.max_decoder_steps,
|
||||
# element_shape=tf.TensorShape(
|
||||
# []),
|
||||
# clear_after_read=False,
|
||||
# dynamic_size=False)
|
||||
attentions = ()
|
||||
stop_tokens = ()
|
||||
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
|
||||
# iter vars
|
||||
stop_flag = tf.constant(False, dtype=tf.bool)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, states, outputs, stop_flag):
|
||||
frame_next = states[0]
|
||||
prenet_next = self.prenet(frame_next, training=False)
|
||||
output, stop_token, states, _ = self.step(prenet_next,
|
||||
states,
|
||||
None,
|
||||
training=False)
|
||||
stop_token = tf.math.sigmoid(stop_token)
|
||||
stop_flag = tf.greater(stop_token, self.stop_thresh)
|
||||
stop_flag = tf.reduce_all(stop_flag)
|
||||
# stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
|
||||
|
||||
outputs = outputs.write(step, tf.reshape(output, [-1]))
|
||||
return step + 1, memory, states, outputs, stop_flag
|
||||
|
||||
cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
|
||||
step_count, memory, states, outputs, stop_flag = \
|
||||
tf.while_loop(cond,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, states, outputs,
|
||||
stop_flag),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=self.max_decoder_steps)
|
||||
|
||||
|
||||
outputs = outputs.stack()
|
||||
outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter
|
||||
outputs = tf.expand_dims(outputs, axis=[0])
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
|
||||
def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
|
||||
if training:
|
||||
return self.decode(memory, states, frames, memory_seq_length)
|
||||
if self.enable_tflite:
|
||||
return self.decode_inference_tflite(memory, states)
|
||||
return self.decode_inference(memory, states)
|
Loading…
Reference in New Issue