build inference graph for tf models, and update naming convention for TF models

This commit is contained in:
erogol 2020-07-08 10:22:46 +02:00
parent 07d2d28ae6
commit 52473d4853
4 changed files with 31 additions and 20 deletions

View File

@ -26,7 +26,7 @@ parser.add_argument('--config_path',
help='Path to config file of torch model.')
parser.add_argument('--output_path',
type=str,
help='path to save TF model weights.')
help='path to output file including file name to save TF model.')
args = parser.parse_args()
# load model config
@ -65,18 +65,18 @@ model_tf = Tacotron2(num_chars=num_chars,
# TODO: set layer names so that we can remove these manual matching
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
var_map = [
('tacotron2/embedding/embeddings:0', 'embedding.weight'),
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
('embedding/embeddings:0', 'embedding.weight'),
('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
'encoder.lstm.weight_ih_l0'),
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
'encoder.lstm.weight_hh_l0'),
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
'encoder.lstm.weight_ih_l0_reverse'),
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
'encoder.lstm.weight_hh_l0_reverse'),
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
('encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
('encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
('decoder/linear_projection/kernel:0',
@ -86,8 +86,7 @@ var_map = [
# %%
# get tf_model graph
input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
mel_pred = model_tf(input_ids, training=False)
mel_pred = model_tf.build_inference()
# get tf variables
tf_vars = model_tf.weights

View File

@ -109,12 +109,17 @@ class Attention(keras.layers.Layer):
raise ValueError("Unknown value for attention norm type")
def init_states(self, batch_size, value_length):
states = ()
states = []
if self.use_loc_attn:
attention_cum = tf.zeros([batch_size, value_length])
attention_old = tf.zeros([batch_size, value_length])
states = (attention_cum, attention_old)
return states
states = [attention_cum, attention_old]
if self.use_forward_attn:
alpha = tf.concat(
[tf.ones([batch_size, 1]),
tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], axis=1)
states.append(alpha)
return tuple(states)
def process_values(self, values):
""" cache values for decoder iterations """
@ -125,7 +130,7 @@ class Attention(keras.layers.Layer):
def get_loc_attn(self, query, states):
""" compute location attention, query layer and
unnorm. attention weights"""
attention_cum, attention_old = states
attention_cum, attention_old = states[:2]
attn_cat = tf.stack([attention_old, attention_cum], axis=2)
processed_query = self.query_layer(tf.expand_dims(query, 1))

View File

@ -79,8 +79,8 @@ class Decoder(keras.layers.Layer):
prenet_dropout,
[self.prenet_dim, self.prenet_dim],
bias=False,
name='prenet')
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', )
name=f'prenet')
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'attention_rnn', )
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
# TODO: implement other attn options
@ -94,10 +94,10 @@ class Decoder(keras.layers.Layer):
use_trans_agent=use_trans_agent,
use_forward_attn_mask=use_forward_attn_mask,
name='attention')
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn')
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'decoder_rnn')
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer')
self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer')
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'linear_projection/linear_layer')
self.stopnet = keras.layers.Dense(1, name=f'stopnet/linear_layer')
def set_max_decoder_steps(self, new_max_steps):

View File

@ -1,3 +1,4 @@
import tensorflow as tf
from tensorflow import keras
from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
@ -48,9 +49,11 @@ class Tacotron2(keras.models.Model):
use_location_attn=location_attn,
attn_K=attn_K,
separate_stopnet=separate_stopnet,
speaker_emb_dim=self.speaker_embed_dim)
speaker_emb_dim=self.speaker_embed_dim,
name='decoder')
self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
@tf.function(experimental_relax_shapes=True)
def call(self, characters, text_lengths=None, frames=None, training=None):
if training:
return self.training(characters, text_lengths, frames)
@ -79,3 +82,7 @@ class Tacotron2(keras.models.Model):
print(output_frames.shape)
return decoder_frames, output_frames, attentions, stop_tokens
def build_inference(self, ):
input_ids = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)
self(input_ids)