From 52473d4853e23f46b4d1b862bdd56f2982da34a7 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 8 Jul 2020 10:22:46 +0200 Subject: [PATCH] build inference graph for tf models, and update naming convention for TF models --- tf/convert_tacotron2_torch_to_tf.py | 19 +++++++++---------- tf/layers/common_layers.py | 13 +++++++++---- tf/layers/tacotron2.py | 10 +++++----- tf/models/tacotron2.py | 9 ++++++++- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/tf/convert_tacotron2_torch_to_tf.py b/tf/convert_tacotron2_torch_to_tf.py index b1878343..dfc42250 100644 --- a/tf/convert_tacotron2_torch_to_tf.py +++ b/tf/convert_tacotron2_torch_to_tf.py @@ -26,7 +26,7 @@ parser.add_argument('--config_path', help='Path to config file of torch model.') parser.add_argument('--output_path', type=str, - help='path to save TF model weights.') + help='path to output file including file name to save TF model.') args = parser.parse_args() # load model config @@ -65,18 +65,18 @@ model_tf = Tacotron2(num_chars=num_chars, # TODO: set layer names so that we can remove these manual matching common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE' var_map = [ - ('tacotron2/embedding/embeddings:0', 'embedding.weight'), - ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', + ('embedding/embeddings:0', 'embedding.weight'), + ('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', 'encoder.lstm.weight_ih_l0'), - ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', + ('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0'), - ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', + ('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', 'encoder.lstm.weight_ih_l0_reverse'), - ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', + ('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0_reverse'), - ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0', + ('encoder/lstm/forward_lstm/lstm_cell_1/bias:0', ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')), - ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0', + ('encoder/lstm/backward_lstm/lstm_cell_2/bias:0', ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')), ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'), ('decoder/linear_projection/kernel:0', @@ -86,8 +86,7 @@ var_map = [ # %% # get tf_model graph -input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs() -mel_pred = model_tf(input_ids, training=False) +mel_pred = model_tf.build_inference() # get tf variables tf_vars = model_tf.weights diff --git a/tf/layers/common_layers.py b/tf/layers/common_layers.py index 995b5490..195acfed 100644 --- a/tf/layers/common_layers.py +++ b/tf/layers/common_layers.py @@ -109,12 +109,17 @@ class Attention(keras.layers.Layer): raise ValueError("Unknown value for attention norm type") def init_states(self, batch_size, value_length): - states = () + states = [] if self.use_loc_attn: attention_cum = tf.zeros([batch_size, value_length]) attention_old = tf.zeros([batch_size, value_length]) - states = (attention_cum, attention_old) - return states + states = [attention_cum, attention_old] + if self.use_forward_attn: + alpha = tf.concat( + [tf.ones([batch_size, 1]), + tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], axis=1) + states.append(alpha) + return tuple(states) def process_values(self, values): """ cache values for decoder iterations """ @@ -125,7 +130,7 @@ class Attention(keras.layers.Layer): def get_loc_attn(self, query, states): """ compute location attention, query layer and unnorm. attention weights""" - attention_cum, attention_old = states + attention_cum, attention_old = states[:2] attn_cat = tf.stack([attention_old, attention_cum], axis=2) processed_query = self.query_layer(tf.expand_dims(query, 1)) diff --git a/tf/layers/tacotron2.py b/tf/layers/tacotron2.py index c6f1a2cd..e19be84b 100644 --- a/tf/layers/tacotron2.py +++ b/tf/layers/tacotron2.py @@ -79,8 +79,8 @@ class Decoder(keras.layers.Layer): prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, - name='prenet') - self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', ) + name=f'prenet') + self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'attention_rnn', ) self.attention_rnn_dropout = keras.layers.Dropout(0.5) # TODO: implement other attn options @@ -94,10 +94,10 @@ class Decoder(keras.layers.Layer): use_trans_agent=use_trans_agent, use_forward_attn_mask=use_forward_attn_mask, name='attention') - self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn') + self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'decoder_rnn') self.decoder_rnn_dropout = keras.layers.Dropout(0.5) - self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer') - self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer') + self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'linear_projection/linear_layer') + self.stopnet = keras.layers.Dense(1, name=f'stopnet/linear_layer') def set_max_decoder_steps(self, new_max_steps): diff --git a/tf/models/tacotron2.py b/tf/models/tacotron2.py index 101291cf..b9a14e2b 100644 --- a/tf/models/tacotron2.py +++ b/tf/models/tacotron2.py @@ -1,3 +1,4 @@ +import tensorflow as tf from tensorflow import keras from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet @@ -48,9 +49,11 @@ class Tacotron2(keras.models.Model): use_location_attn=location_attn, attn_K=attn_K, separate_stopnet=separate_stopnet, - speaker_emb_dim=self.speaker_embed_dim) + speaker_emb_dim=self.speaker_embed_dim, + name='decoder') self.postnet = Postnet(postnet_output_dim, 5, name='postnet') + @tf.function(experimental_relax_shapes=True) def call(self, characters, text_lengths=None, frames=None, training=None): if training: return self.training(characters, text_lengths, frames) @@ -79,3 +82,7 @@ class Tacotron2(keras.models.Model): print(output_frames.shape) return decoder_frames, output_frames, attentions, stop_tokens + def build_inference(self, ): + input_ids = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) + self(input_ids) +