From 4396f8e2da35796ffb2598726fbbcf45b13fd74b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 3 Mar 2021 15:40:53 +0100 Subject: [PATCH] continue refactoring --- TTS/tts/layers/tacotron/__init__.py | 0 TTS/tts/layers/{ => tacotron}/attentions.py | 2 +- .../layers/{ => tacotron}/common_layers.py | 0 TTS/tts/layers/{ => tacotron}/gst_layers.py | 0 TTS/tts/layers/{ => tacotron}/tacotron.py | 0 TTS/tts/layers/{ => tacotron}/tacotron2.py | 0 TTS/tts/tf/layers/common_layers.py | 288 ----------------- TTS/tts/tf/layers/tacotron/__init__.py | 0 TTS/tts/tf/layers/tacotron2.py | 302 ------------------ 9 files changed, 1 insertion(+), 591 deletions(-) create mode 100644 TTS/tts/layers/tacotron/__init__.py rename TTS/tts/layers/{ => tacotron}/attentions.py (99%) rename TTS/tts/layers/{ => tacotron}/common_layers.py (100%) rename TTS/tts/layers/{ => tacotron}/gst_layers.py (100%) rename TTS/tts/layers/{ => tacotron}/tacotron.py (100%) rename TTS/tts/layers/{ => tacotron}/tacotron2.py (100%) delete mode 100644 TTS/tts/tf/layers/common_layers.py create mode 100644 TTS/tts/tf/layers/tacotron/__init__.py delete mode 100644 TTS/tts/tf/layers/tacotron2.py diff --git a/TTS/tts/layers/tacotron/__init__.py b/TTS/tts/layers/tacotron/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/layers/attentions.py b/TTS/tts/layers/tacotron/attentions.py similarity index 99% rename from TTS/tts/layers/attentions.py rename to TTS/tts/layers/tacotron/attentions.py index f7c720a7..be35deb8 100644 --- a/TTS/tts/layers/attentions.py +++ b/TTS/tts/layers/tacotron/attentions.py @@ -2,7 +2,7 @@ import torch from torch import nn from torch.nn import functional as F -from TTS.tts.layers.common_layers import Linear +from TTS.tts.layers.tacotron.common_layers import Linear from scipy.stats import betabinom diff --git a/TTS/tts/layers/common_layers.py b/TTS/tts/layers/tacotron/common_layers.py similarity index 100% rename from TTS/tts/layers/common_layers.py rename to TTS/tts/layers/tacotron/common_layers.py diff --git a/TTS/tts/layers/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py similarity index 100% rename from TTS/tts/layers/gst_layers.py rename to TTS/tts/layers/tacotron/gst_layers.py diff --git a/TTS/tts/layers/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py similarity index 100% rename from TTS/tts/layers/tacotron.py rename to TTS/tts/layers/tacotron/tacotron.py diff --git a/TTS/tts/layers/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py similarity index 100% rename from TTS/tts/layers/tacotron2.py rename to TTS/tts/layers/tacotron/tacotron2.py diff --git a/TTS/tts/tf/layers/common_layers.py b/TTS/tts/tf/layers/common_layers.py deleted file mode 100644 index ad18b9fc..00000000 --- a/TTS/tts/tf/layers/common_layers.py +++ /dev/null @@ -1,288 +0,0 @@ -import tensorflow as tf -from tensorflow import keras -from tensorflow.python.ops import math_ops -# from tensorflow_addons.seq2seq import BahdanauAttention - -# NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg - -class Linear(keras.layers.Layer): - def __init__(self, units, use_bias, **kwargs): - super(Linear, self).__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') - self.activation = keras.layers.ReLU() - - def call(self, x): - """ - shapes: - x: B x T x C - """ - return self.activation(self.linear_layer(x)) - - -class LinearBN(keras.layers.Layer): - def __init__(self, units, use_bias, **kwargs): - super(LinearBN, self).__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') - self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization') - self.activation = keras.layers.ReLU() - - def call(self, x, training=None): - """ - shapes: - x: B x T x C - """ - out = self.linear_layer(x) - out = self.batch_normalization(out, training=training) - return self.activation(out) - - -class Prenet(keras.layers.Layer): - def __init__(self, - prenet_type, - prenet_dropout, - units, - bias, - **kwargs): - super(Prenet, self).__init__(**kwargs) - self.prenet_type = prenet_type - self.prenet_dropout = prenet_dropout - self.linear_layers = [] - if prenet_type == "bn": - self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] - elif prenet_type == "original": - self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] - else: - raise RuntimeError(' [!] Unknown prenet type.') - if prenet_dropout: - self.dropout = keras.layers.Dropout(rate=0.5) - - def call(self, x, training=None): - """ - shapes: - x: B x T x C - """ - for linear in self.linear_layers: - if self.prenet_dropout: - x = self.dropout(linear(x), training=training) - else: - x = linear(x) - return x - - -def _sigmoid_norm(score): - attn_weights = tf.nn.sigmoid(score) - attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True) - return attn_weights - - -class Attention(keras.layers.Layer): - """TODO: implement forward_attention - TODO: location sensitive attention - TODO: implement attention windowing """ - def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters, - loc_attn_kernel_size, use_windowing, norm, use_forward_attn, - use_trans_agent, use_forward_attn_mask, **kwargs): - super(Attention, self).__init__(**kwargs) - self.use_loc_attn = use_loc_attn - self.loc_attn_n_filters = loc_attn_n_filters - self.loc_attn_kernel_size = loc_attn_kernel_size - self.use_windowing = use_windowing - self.norm = norm - self.use_forward_attn = use_forward_attn - self.use_trans_agent = use_trans_agent - self.use_forward_attn_mask = use_forward_attn_mask - self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer') - self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer') - self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer') - if use_loc_attn: - self.location_conv1d = keras.layers.Conv1D( - filters=loc_attn_n_filters, - kernel_size=loc_attn_kernel_size, - padding='same', - use_bias=False, - name='location_layer/location_conv1d') - self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense') - if norm == 'softmax': - self.norm_func = tf.nn.softmax - elif norm == 'sigmoid': - self.norm_func = _sigmoid_norm - else: - raise ValueError("Unknown value for attention norm type") - - def init_states(self, batch_size, value_length): - states = [] - if self.use_loc_attn: - attention_cum = tf.zeros([batch_size, value_length]) - attention_old = tf.zeros([batch_size, value_length]) - states = [attention_cum, attention_old] - if self.use_forward_attn: - alpha = tf.concat([ - tf.ones([batch_size, 1]), - tf.zeros([batch_size, value_length])[:, :-1] + 1e-7 - ], 1) - states.append(alpha) - return tuple(states) - - def process_values(self, values): - """ cache values for decoder iterations """ - #pylint: disable=attribute-defined-outside-init - self.processed_values = self.inputs_layer(values) - self.values = values - - def get_loc_attn(self, query, states): - """ compute location attention, query layer and - unnorm. attention weights""" - attention_cum, attention_old = states[:2] - attn_cat = tf.stack([attention_old, attention_cum], axis=2) - - processed_query = self.query_layer(tf.expand_dims(query, 1)) - processed_attn = self.location_dense(self.location_conv1d(attn_cat)) - score = self.v( - tf.nn.tanh(self.processed_values + processed_query + - processed_attn)) - score = tf.squeeze(score, axis=2) - return score, processed_query - - def get_attn(self, query): - """ compute query layer and unnormalized attention weights """ - processed_query = self.query_layer(tf.expand_dims(query, 1)) - score = self.v(tf.nn.tanh(self.processed_values + processed_query)) - score = tf.squeeze(score, axis=2) - return score, processed_query - - def apply_score_masking(self, score, mask): #pylint: disable=no-self-use - """ ignore sequence paddings """ - padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) - # Bias so padding positions do not contribute to attention distribution. - score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) - return score - - def apply_forward_attention(self, alignment, alpha): #pylint: disable=no-self-use - # forward attention - fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0) - # compute transition potentials - new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment - # renormalize attention weights - new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True) - return new_alpha - - def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None): - states = [] - if self.use_loc_attn: - states = [old_states[0] + scores_norm, attn_weights] - if self.use_forward_attn: - states.append(new_alpha) - return tuple(states) - - def call(self, query, states): - """ - shapes: - query: B x D - """ - if self.use_loc_attn: - score, _ = self.get_loc_attn(query, states) - else: - score, _ = self.get_attn(query) - - # TODO: masking - # if mask is not None: - # self.apply_score_masking(score, mask) - # attn_weights shape == (batch_size, max_length, 1) - - # normalize attention scores - scores_norm = self.norm_func(score) - attn_weights = scores_norm - - # apply forward attention - new_alpha = None - if self.use_forward_attn: - new_alpha = self.apply_forward_attention(attn_weights, states[-1]) - attn_weights = new_alpha - - # update states tuple - # states = (cum_attn_weights, attn_weights, new_alpha) - states = self.update_states(states, scores_norm, attn_weights, new_alpha) - - # context_vector shape after sum == (batch_size, hidden_size) - context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False) - context_vector = tf.squeeze(context_vector, axis=1) - return context_vector, attn_weights, states - - -# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b): -# dtype = processed_query.dtype -# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1] -# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2]) - - -# class LocationSensitiveAttention(BahdanauAttention): -# def __init__(self, -# units, -# memory=None, -# memory_sequence_length=None, -# normalize=False, -# probability_fn="softmax", -# kernel_initializer="glorot_uniform", -# dtype=None, -# name="LocationSensitiveAttention", -# location_attention_filters=32, -# location_attention_kernel_size=31): - -# super(LocationSensitiveAttention, -# self).__init__(units=units, -# memory=memory, -# memory_sequence_length=memory_sequence_length, -# normalize=normalize, -# probability_fn='softmax', ## parent module default -# kernel_initializer=kernel_initializer, -# dtype=dtype, -# name=name) -# if probability_fn == 'sigmoid': -# self.probability_fn = lambda score, _: self._sigmoid_normalization(score) -# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False) -# self.location_dense = keras.layers.Dense(units, use_bias=False) -# # self.v = keras.layers.Dense(1, use_bias=True) - -# def _location_sensitive_score(self, processed_query, keys, processed_loc): -# processed_query = tf.expand_dims(processed_query, 1) -# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2]) - -# def _location_sensitive(self, alignment_cum, alignment_old): -# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2) -# return self.location_dense(self.location_conv(alignment_cat)) - -# def _sigmoid_normalization(self, score): -# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True) - -# # def _apply_masking(self, score, mask): -# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) -# # # Bias so padding positions do not contribute to attention distribution. -# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) -# # return score - -# def _calculate_attention(self, query, state): -# alignment_cum, alignment_old = state[:2] -# processed_query = self.query_layer( -# query) if self.query_layer else query -# processed_loc = self._location_sensitive(alignment_cum, alignment_old) -# score = self._location_sensitive_score( -# processed_query, -# self.keys, -# processed_loc) -# alignment = self.probability_fn(score, state) -# alignment_cum = alignment_cum + alignment -# state[0] = alignment_cum -# state[1] = alignment -# return alignment, state - -# def compute_context(self, alignments): -# expanded_alignments = tf.expand_dims(alignments, 1) -# context = tf.matmul(expanded_alignments, self.values) -# context = tf.squeeze(context, [1]) -# return context - -# # def call(self, query, state): -# # alignment, next_state = self._calculate_attention(query, state) -# # return alignment, next_state diff --git a/TTS/tts/tf/layers/tacotron/__init__.py b/TTS/tts/tf/layers/tacotron/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/tf/layers/tacotron2.py b/TTS/tts/tf/layers/tacotron2.py deleted file mode 100644 index 50a766a9..00000000 --- a/TTS/tts/tf/layers/tacotron2.py +++ /dev/null @@ -1,302 +0,0 @@ -import tensorflow as tf -from tensorflow import keras -from TTS.tts.tf.utils.tf_utils import shape_list -from TTS.tts.tf.layers.common_layers import Prenet, Attention - - -# NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg -class ConvBNBlock(keras.layers.Layer): - def __init__(self, filters, kernel_size, activation, **kwargs): - super(ConvBNBlock, self).__init__(**kwargs) - self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d') - self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization') - self.dropout = keras.layers.Dropout(rate=0.5, name='dropout') - self.activation = keras.layers.Activation(activation, name='activation') - - def call(self, x, training=None): - o = self.convolution1d(x) - o = self.batch_normalization(o, training=training) - o = self.activation(o) - o = self.dropout(o, training=training) - return o - - -class Postnet(keras.layers.Layer): - def __init__(self, output_filters, num_convs, **kwargs): - super(Postnet, self).__init__(**kwargs) - self.convolutions = [] - self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0')) - for idx in range(1, num_convs - 1): - self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}')) - self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}')) - - def call(self, x, training=None): - o = x - for layer in self.convolutions: - o = layer(o, training=training) - return o - - -class Encoder(keras.layers.Layer): - def __init__(self, output_input_dim, **kwargs): - super(Encoder, self).__init__(**kwargs) - self.convolutions = [] - for idx in range(3): - self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}')) - self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm') - - def call(self, x, training=None): - o = x - for layer in self.convolutions: - o = layer(o, training=training) - o = self.lstm(o) - return o - - -class Decoder(keras.layers.Layer): - #pylint: disable=unused-argument - def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type, - prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask, - use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs): - super(Decoder, self).__init__(**kwargs) - self.frame_dim = frame_dim - self.r_init = tf.constant(r, dtype=tf.int32) - self.r = tf.constant(r, dtype=tf.int32) - self.output_dim = r * self.frame_dim - self.separate_stopnet = separate_stopnet - self.enable_tflite = enable_tflite - - # layer constants - self.max_decoder_steps = tf.constant(1000, dtype=tf.int32) - self.stop_thresh = tf.constant(0.5, dtype=tf.float32) - - # model dimensions - self.query_dim = 1024 - self.decoder_rnn_dim = 1024 - self.prenet_dim = 256 - self.attn_dim = 128 - self.p_attention_dropout = 0.1 - self.p_decoder_dropout = 0.1 - - self.prenet = Prenet(prenet_type, - prenet_dropout, - [self.prenet_dim, self.prenet_dim], - bias=False, - name='prenet') - self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', ) - self.attention_rnn_dropout = keras.layers.Dropout(0.5) - - # TODO: implement other attn options - self.attention = Attention(attn_dim=self.attn_dim, - use_loc_attn=True, - loc_attn_n_filters=32, - loc_attn_kernel_size=31, - use_windowing=False, - norm=attn_norm, - use_forward_attn=use_forward_attn, - use_trans_agent=use_trans_agent, - use_forward_attn_mask=use_forward_attn_mask, - name='attention') - self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn') - self.decoder_rnn_dropout = keras.layers.Dropout(0.5) - self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer') - self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer') - - - def set_max_decoder_steps(self, new_max_steps): - self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32) - - def set_r(self, new_r): - self.r = tf.constant(new_r, dtype=tf.int32) - self.output_dim = self.frame_dim * new_r - - def build_decoder_initial_states(self, batch_size, memory_dim, memory_length): - zero_frame = tf.zeros([batch_size, self.frame_dim]) - zero_context = tf.zeros([batch_size, memory_dim]) - attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) - decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) - attention_states = self.attention.init_states(batch_size, memory_length) - return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states - - def step(self, prenet_next, states, - memory_seq_length=None, training=None): - _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states - attention_rnn_input = tf.concat([prenet_next, context_next], -1) - attention_rnn_output, attention_rnn_state = \ - self.attention_rnn(attention_rnn_input, - attention_rnn_state, training=training) - attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training) - context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training) - decoder_rnn_input = tf.concat([attention_rnn_output, context], -1) - decoder_rnn_output, decoder_rnn_state = \ - self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training) - decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training) - linear_projection_input = tf.concat([decoder_rnn_output, context], -1) - output_frame = self.linear_projection(linear_projection_input, training=training) - stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1) - stopnet_output = self.stopnet(stopnet_input, training=training) - output_frame = output_frame[:, :self.r * self.frame_dim] - states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states) - return output_frame, stopnet_output, states, attention - - def decode(self, memory, states, frames, memory_seq_length=None): - B, _, _ = shape_list(memory) - num_iter = shape_list(frames)[1] // self.r - # init states - frame_zero = tf.expand_dims(states[0], 1) - frames = tf.concat([frame_zero, frames], axis=1) - outputs = tf.TensorArray(dtype=tf.float32, size=num_iter) - attentions = tf.TensorArray(dtype=tf.float32, size=num_iter) - stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter) - # pre-computes - self.attention.process_values(memory) - prenet_output = self.prenet(frames, training=True) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions): - prenet_next = prenet_output[:, step] - output, stop_token, states, attention = self.step(prenet_next, - states, - memory_seq_length) - outputs = outputs.write(step, output) - attentions = attentions.write(step, attention) - stop_tokens = stop_tokens.write(step, stop_token) - return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions - _, memory, _, states, outputs, stop_tokens, attentions = \ - tf.while_loop(lambda *arg: True, - _body, - loop_vars=(step_count, memory, prenet_output, - states, outputs, stop_tokens, attentions), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=num_iter) - - outputs = outputs.stack() - attentions = attentions.stack() - stop_tokens = stop_tokens.stack() - outputs = tf.transpose(outputs, [1, 0, 2]) - attentions = tf.transpose(attentions, [1, 0, 2]) - stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) - stop_tokens = tf.squeeze(stop_tokens, axis=2) - outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - def decode_inference(self, memory, states): - B, _, _ = shape_list(memory) - # init states - outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - - # pre-computes - self.attention.process_values(memory) - - # iter vars - stop_flag = tf.constant(False, dtype=tf.bool) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag): - frame_next = states[0] - prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, attention = self.step(prenet_next, - states, - None, - training=False) - stop_token = tf.math.sigmoid(stop_token) - outputs = outputs.write(step, output) - attentions = attentions.write(step, attention) - stop_tokens = stop_tokens.write(step, stop_token) - stop_flag = tf.greater(stop_token, self.stop_thresh) - stop_flag = tf.reduce_all(stop_flag) - return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag - - cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - _, memory, states, outputs, stop_tokens, attentions, stop_flag = \ - tf.while_loop(cond, - _body, - loop_vars=(step_count, memory, states, outputs, - stop_tokens, attentions, stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps) - - outputs = outputs.stack() - attentions = attentions.stack() - stop_tokens = stop_tokens.stack() - - outputs = tf.transpose(outputs, [1, 0, 2]) - attentions = tf.transpose(attentions, [1, 0, 2]) - stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) - stop_tokens = tf.squeeze(stop_tokens, axis=2) - outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - def decode_inference_tflite(self, memory, states): - """Inference with TF-Lite compatibility. It assumes - batch_size is 1""" - # init states - # dynamic_shape is not supported in TFLite - outputs = tf.TensorArray(dtype=tf.float32, - size=self.max_decoder_steps, - element_shape=tf.TensorShape( - [self.output_dim]), - clear_after_read=False, - dynamic_size=False) - # stop_flags = tf.TensorArray(dtype=tf.bool, - # size=self.max_decoder_steps, - # element_shape=tf.TensorShape( - # []), - # clear_after_read=False, - # dynamic_size=False) - attentions = () - stop_tokens = () - - # pre-computes - self.attention.process_values(memory) - - # iter vars - stop_flag = tf.constant(False, dtype=tf.bool) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, states, outputs, stop_flag): - frame_next = states[0] - prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, _ = self.step(prenet_next, - states, - None, - training=False) - stop_token = tf.math.sigmoid(stop_token) - stop_flag = tf.greater(stop_token, self.stop_thresh) - stop_flag = tf.reduce_all(stop_flag) - # stop_flags = stop_flags.write(step, tf.logical_not(stop_flag)) - - outputs = outputs.write(step, tf.reshape(output, [-1])) - return step + 1, memory, states, outputs, stop_flag - - cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - step_count, memory, states, outputs, stop_flag = \ - tf.while_loop(cond, - _body, - loop_vars=(step_count, memory, states, outputs, - stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps) - - - outputs = outputs.stack() - outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter - outputs = tf.expand_dims(outputs, axis=[0]) - outputs = tf.transpose(outputs, [1, 0, 2]) - outputs = tf.reshape(outputs, [1, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - - def call(self, memory, states, frames=None, memory_seq_length=None, training=False): - if training: - return self.decode(memory, states, frames, memory_seq_length) - if self.enable_tflite: - return self.decode_inference_tflite(memory, states) - return self.decode_inference(memory, states)