mirror of https://github.com/coqui-ai/TTS.git
Build inference graph for TF models, and update naming convention for TF models
This commit is contained in:
parent 07d2d28ae6
commit 52473d4853
@@ -26,7 +26,7 @@ parser.add_argument('--config_path',
                     help='Path to config file of torch model.')
 parser.add_argument('--output_path',
                     type=str,
-                    help='path to save TF model weights.')
+                    help='path to output file including file name to save TF model.')
 args = parser.parse_args()
 
 # load model config
@@ -65,18 +65,18 @@ model_tf = Tacotron2(num_chars=num_chars,
 # TODO: set layer names so that we can remove these manual matching
 common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
 var_map = [
-    ('tacotron2/embedding/embeddings:0', 'embedding.weight'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
+    ('embedding/embeddings:0', 'embedding.weight'),
+    ('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
     'encoder.lstm.weight_ih_l0'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
+    ('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
     'encoder.lstm.weight_hh_l0'),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
+    ('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
     'encoder.lstm.weight_ih_l0_reverse'),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
+    ('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
     'encoder.lstm.weight_hh_l0_reverse'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
+    ('encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
     ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
+    ('encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
     ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
     ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
     ('decoder/linear_projection/kernel:0',
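The pairs above map TF variable names (left, now without the `tacotron2/` model prefix) to torch `state_dict` keys (right). A hedged sketch of how a copy loop could consume them — `transfer_weights`, the `state_dict` argument, and the transpose/bias handling are illustrative assumptions, not code from this commit:

import numpy as np

def transfer_weights(tf_vars, state_dict, var_map):
    # illustrative helper: match each TF variable to its torch tensor(s)
    name_map = dict(var_map)
    for tf_var in tf_vars:
        torch_name = name_map.get(tf_var.name)
        if torch_name is None:
            continue  # unmapped variables keep their initializer values
        if isinstance(torch_name, tuple):
            # Keras LSTMCell holds a single bias; torch splits it into
            # bias_ih/bias_hh, which act additively, so sum the two
            value = (state_dict[torch_name[0]] + state_dict[torch_name[1]]).numpy()
        else:
            value = state_dict[torch_name].numpy()
        if tuple(tf_var.shape) != value.shape:
            value = np.transpose(value)  # torch stores dense kernels transposed
        tf_var.assign(value)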
@@ -86,8 +86,7 @@ var_map = [
 
 # %%
 # get tf_model graph
-input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
-mel_pred = model_tf(input_ids, training=False)
+mel_pred = model_tf.build_inference()
 
 # get tf variables
 tf_vars = model_tf.weights
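The dummy-input call is replaced because Keras creates variables lazily: until the model has run once, `model_tf.weights` is empty and there is nothing to match against `var_map`. A minimal, self-contained illustration of that behavior (using a plain Dense layer, not this repo's code):

import tensorflow as tf
from tensorflow import keras

layer = keras.layers.Dense(4)
print(len(layer.weights))   # 0 -- no variables exist yet
layer(tf.zeros([1, 3]))     # first call builds kernel and bias
print(len(layer.weights))   # 2 -- now they can be matched and assigned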
@@ -109,12 +109,17 @@ class Attention(keras.layers.Layer):
             raise ValueError("Unknown value for attention norm type")
 
     def init_states(self, batch_size, value_length):
-        states = ()
+        states = []
         if self.use_loc_attn:
             attention_cum = tf.zeros([batch_size, value_length])
             attention_old = tf.zeros([batch_size, value_length])
-            states = (attention_cum, attention_old)
-        return states
+            states = [attention_cum, attention_old]
+        if self.use_forward_attn:
+            alpha = tf.concat(
+                [tf.ones([batch_size, 1]),
+                 tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], axis=1)
+            states.append(alpha)
+        return tuple(states)
 
     def process_values(self, values):
         """ cache values for decoder iterations """
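`init_states` now builds the state container as a list so the forward-attention term can be appended only when `use_forward_attn` is set, then freezes it back into a tuple; downstream code such as `get_loc_attn` in the next hunk slices `states[:2]` so the optional extra element is ignored. The initial `alpha` is effectively one-hot on the first encoder step, with a small epsilon elsewhere for numerical stability. A quick hedged check of that initialization outside the class:

import tensorflow as tf

batch_size, value_length = 2, 5
alpha = tf.concat(
    [tf.ones([batch_size, 1]),
     tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], axis=1)
print(alpha.numpy()[0])  # [1.0, 1e-07, 1e-07, 1e-07, 1e-07]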
@@ -125,7 +130,7 @@ class Attention(keras.layers.Layer):
     def get_loc_attn(self, query, states):
         """ compute location attention, query layer and
         unnorm. attention weights"""
-        attention_cum, attention_old = states
+        attention_cum, attention_old = states[:2]
         attn_cat = tf.stack([attention_old, attention_cum], axis=2)
 
         processed_query = self.query_layer(tf.expand_dims(query, 1))
@@ -79,8 +79,8 @@ class Decoder(keras.layers.Layer):
                             prenet_dropout,
                             [self.prenet_dim, self.prenet_dim],
                             bias=False,
-                            name='prenet')
-        self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', )
+                            name=f'prenet')
+        self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'attention_rnn', )
         self.attention_rnn_dropout = keras.layers.Dropout(0.5)
 
         # TODO: implement other attn options
@@ -94,10 +94,10 @@ class Decoder(keras.layers.Layer):
                                    use_trans_agent=use_trans_agent,
                                    use_forward_attn_mask=use_forward_attn_mask,
                                    name='attention')
-        self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn')
+        self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'decoder_rnn')
         self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
-        self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer')
-        self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer')
+        self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'linear_projection/linear_layer')
+        self.stopnet = keras.layers.Dense(1, name=f'stopnet/linear_layer')
 
 
     def set_max_decoder_steps(self, new_max_steps):
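Dropping `{self.name}/` from the sublayer names matters because Keras already nests variable names under the enclosing layer's scope when the sublayers are built; keeping the prefix would have produced doubled scopes like `decoder/decoder/...` and broken the `var_map` lookups above. A hedged sketch of that scoping behavior (the `Demo` layer is illustrative, and exact name strings can vary across TF versions):

import tensorflow as tf
from tensorflow import keras

class Demo(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # the child name is relative; the parent scope is added automatically
        self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer')

    def call(self, x):
        return self.stopnet(x)

demo = Demo(name='decoder')
demo(tf.zeros([1, 8]))
# expect something like 'decoder/stopnet/linear_layer/kernel:0'
print(demo.stopnet.weights[0].name)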
@@ -1,3 +1,4 @@
+import tensorflow as tf
 from tensorflow import keras
 
 from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
@@ -48,9 +49,11 @@ class Tacotron2(keras.models.Model):
                                use_location_attn=location_attn,
                                attn_K=attn_K,
                                separate_stopnet=separate_stopnet,
-                               speaker_emb_dim=self.speaker_embed_dim)
+                               speaker_emb_dim=self.speaker_embed_dim,
+                               name='decoder')
         self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
 
+    @tf.function(experimental_relax_shapes=True)
     def call(self, characters, text_lengths=None, frames=None, training=None):
         if training:
             return self.training(characters, text_lengths, frames)
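`experimental_relax_shapes=True` asks `tf.function` to generalize input shapes when retracing, so inference over texts of different lengths converges to one relaxed graph instead of compiling a fresh trace per length. A small standalone illustration (not from this commit):

import tensorflow as tf

@tf.function(experimental_relax_shapes=True)
def summarize(x):
    return tf.reduce_sum(x)

summarize(tf.zeros([1, 4]))
summarize(tf.zeros([1, 9]))  # varying lengths reuse a relaxed trace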
@@ -79,3 +82,7 @@ class Tacotron2(keras.models.Model):
         print(output_frames.shape)
         return decoder_frames, output_frames, attentions, stop_tokens
 
+    def build_inference(self, ):
+        input_ids = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)
+        self(input_ids)
+
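Putting it together, a hedged usage sketch of the new inference path, assuming a configured `model_tf` instance as in the converter script (the input values are placeholders):

import tensorflow as tf

model_tf.build_inference()   # one dummy forward pass so all weights exist
# ... transfer torch weights into model_tf.weights here ...
input_ids = tf.constant([[1, 2, 3, 4]], dtype=tf.int32)
decoder_frames, output_frames, attentions, stop_tokens = model_tf(
    input_ids, training=False)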