mirror of https://github.com/coqui-ai/TTS.git
build inference graph for tf models, and update naming convention for TF models
This commit is contained in:
parent
07d2d28ae6
commit
52473d4853
|
@ -26,7 +26,7 @@ parser.add_argument('--config_path',
|
|||
help='Path to config file of torch model.')
|
||||
parser.add_argument('--output_path',
|
||||
type=str,
|
||||
help='path to save TF model weights.')
|
||||
help='path to output file including file name to save TF model.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# load model config
|
||||
|
@ -65,18 +65,18 @@ model_tf = Tacotron2(num_chars=num_chars,
|
|||
# TODO: set layer names so that we can remove these manual matching
|
||||
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
|
||||
var_map = [
|
||||
('tacotron2/embedding/embeddings:0', 'embedding.weight'),
|
||||
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
|
||||
('embedding/embeddings:0', 'embedding.weight'),
|
||||
('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
|
||||
'encoder.lstm.weight_ih_l0'),
|
||||
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
|
||||
('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
|
||||
'encoder.lstm.weight_hh_l0'),
|
||||
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
|
||||
('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
|
||||
'encoder.lstm.weight_ih_l0_reverse'),
|
||||
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
|
||||
('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
|
||||
'encoder.lstm.weight_hh_l0_reverse'),
|
||||
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
|
||||
('encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
|
||||
('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
|
||||
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
|
||||
('encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
|
||||
('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
|
||||
('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
|
||||
('decoder/linear_projection/kernel:0',
|
||||
|
@ -86,8 +86,7 @@ var_map = [
|
|||
|
||||
# %%
|
||||
# get tf_model graph
|
||||
input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
|
||||
mel_pred = model_tf(input_ids, training=False)
|
||||
mel_pred = model_tf.build_inference()
|
||||
|
||||
# get tf variables
|
||||
tf_vars = model_tf.weights
|
||||
|
|
|
@ -109,12 +109,17 @@ class Attention(keras.layers.Layer):
|
|||
raise ValueError("Unknown value for attention norm type")
|
||||
|
||||
def init_states(self, batch_size, value_length):
|
||||
states = ()
|
||||
states = []
|
||||
if self.use_loc_attn:
|
||||
attention_cum = tf.zeros([batch_size, value_length])
|
||||
attention_old = tf.zeros([batch_size, value_length])
|
||||
states = (attention_cum, attention_old)
|
||||
return states
|
||||
states = [attention_cum, attention_old]
|
||||
if self.use_forward_attn:
|
||||
alpha = tf.concat(
|
||||
[tf.ones([batch_size, 1]),
|
||||
tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], axis=1)
|
||||
states.append(alpha)
|
||||
return tuple(states)
|
||||
|
||||
def process_values(self, values):
|
||||
""" cache values for decoder iterations """
|
||||
|
@ -125,7 +130,7 @@ class Attention(keras.layers.Layer):
|
|||
def get_loc_attn(self, query, states):
|
||||
""" compute location attention, query layer and
|
||||
unnorm. attention weights"""
|
||||
attention_cum, attention_old = states
|
||||
attention_cum, attention_old = states[:2]
|
||||
attn_cat = tf.stack([attention_old, attention_cum], axis=2)
|
||||
|
||||
processed_query = self.query_layer(tf.expand_dims(query, 1))
|
||||
|
|
|
@ -79,8 +79,8 @@ class Decoder(keras.layers.Layer):
|
|||
prenet_dropout,
|
||||
[self.prenet_dim, self.prenet_dim],
|
||||
bias=False,
|
||||
name='prenet')
|
||||
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', )
|
||||
name=f'prenet')
|
||||
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'attention_rnn', )
|
||||
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
|
||||
|
||||
# TODO: implement other attn options
|
||||
|
@ -94,10 +94,10 @@ class Decoder(keras.layers.Layer):
|
|||
use_trans_agent=use_trans_agent,
|
||||
use_forward_attn_mask=use_forward_attn_mask,
|
||||
name='attention')
|
||||
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn')
|
||||
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'decoder_rnn')
|
||||
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
|
||||
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer')
|
||||
self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer')
|
||||
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'linear_projection/linear_layer')
|
||||
self.stopnet = keras.layers.Dense(1, name=f'stopnet/linear_layer')
|
||||
|
||||
|
||||
def set_max_decoder_steps(self, new_max_steps):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
|
||||
from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
|
||||
|
@ -48,9 +49,11 @@ class Tacotron2(keras.models.Model):
|
|||
use_location_attn=location_attn,
|
||||
attn_K=attn_K,
|
||||
separate_stopnet=separate_stopnet,
|
||||
speaker_emb_dim=self.speaker_embed_dim)
|
||||
speaker_emb_dim=self.speaker_embed_dim,
|
||||
name='decoder')
|
||||
self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
|
||||
|
||||
@tf.function(experimental_relax_shapes=True)
|
||||
def call(self, characters, text_lengths=None, frames=None, training=None):
|
||||
if training:
|
||||
return self.training(characters, text_lengths, frames)
|
||||
|
@ -79,3 +82,7 @@ class Tacotron2(keras.models.Model):
|
|||
print(output_frames.shape)
|
||||
return decoder_frames, output_frames, attentions, stop_tokens
|
||||
|
||||
def build_inference(self, ):
|
||||
input_ids = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)
|
||||
self(input_ids)
|
||||
|
||||
|
|
Loading…
Reference in New Issue