mirror of https://github.com/coqui-ai/TTS.git
tf-lite updates for Tacotron2
This commit is contained in:
parent
963ffbd003
commit
d41cb7fe47
|
@ -58,12 +58,16 @@ class Decoder(keras.layers.Layer):
|
|||
#pylint: disable=unused-argument
|
||||
def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
|
||||
prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
|
||||
use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs):
|
||||
use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs):
|
||||
super(Decoder, self).__init__(**kwargs)
|
||||
self.frame_dim = frame_dim
|
||||
self.r_init = tf.constant(r, dtype=tf.int32)
|
||||
self.r = tf.constant(r, dtype=tf.int32)
|
||||
self.output_dim = r * self.frame_dim
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.enable_tflite = enable_tflite
|
||||
|
||||
# layer constants
|
||||
self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
|
||||
self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
|
||||
|
||||
|
@ -105,6 +109,7 @@ class Decoder(keras.layers.Layer):
|
|||
|
||||
def set_r(self, new_r):
|
||||
self.r = tf.constant(new_r, dtype=tf.int32)
|
||||
self.output_dim = self.frame_dim * new_r
|
||||
|
||||
def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
|
||||
zero_frame = tf.zeros([batch_size, self.frame_dim])
|
||||
|
@ -183,6 +188,7 @@ class Decoder(keras.layers.Layer):
|
|||
outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
|
||||
|
@ -226,7 +232,70 @@ class Decoder(keras.layers.Layer):
|
|||
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
def decode_inference_tflite(self, memory, states):
|
||||
"""Inference with TF-Lite compatibility. It assumes
|
||||
batch_size is 1"""
|
||||
# init states
|
||||
# dynamic_shape is not supported in TFLite
|
||||
outputs = tf.TensorArray(dtype=tf.float32,
|
||||
size=self.max_decoder_steps,
|
||||
element_shape=tf.TensorShape(
|
||||
[self.output_dim]),
|
||||
clear_after_read=False,
|
||||
dynamic_size=False)
|
||||
# stop_flags = tf.TensorArray(dtype=tf.bool,
|
||||
# size=self.max_decoder_steps,
|
||||
# element_shape=tf.TensorShape(
|
||||
# []),
|
||||
# clear_after_read=False,
|
||||
# dynamic_size=False)
|
||||
attentions = ()
|
||||
stop_tokens = ()
|
||||
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
|
||||
# iter vars
|
||||
stop_flag = tf.constant(False, dtype=tf.bool)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, states, outputs, stop_flag):
|
||||
frame_next = states[0]
|
||||
prenet_next = self.prenet(frame_next, training=False)
|
||||
output, stop_token, states, _ = self.step(prenet_next,
|
||||
states,
|
||||
None,
|
||||
training=False)
|
||||
stop_token = tf.math.sigmoid(stop_token)
|
||||
stop_flag = tf.greater(stop_token, self.stop_thresh)
|
||||
stop_flag = tf.reduce_all(stop_flag)
|
||||
# stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
|
||||
|
||||
outputs = outputs.write(step, tf.reshape(output, [-1]))
|
||||
return step + 1, memory, states, outputs, stop_flag
|
||||
|
||||
cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
|
||||
step_count, memory, states, outputs, stop_flag = \
|
||||
tf.while_loop(cond,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, states, outputs,
|
||||
stop_flag),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=self.max_decoder_steps)
|
||||
|
||||
|
||||
outputs = outputs.stack()
|
||||
outputs = tf.gather(outputs, tf.range(step_count))
|
||||
outputs = tf.expand_dims(outputs, [0])
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
|
||||
def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
|
||||
if training:
|
||||
return self.decode(memory, states, frames, memory_seq_length)
|
||||
if self.enable_tflite:
|
||||
return self.decode_inference_tflite(memory, states)
|
||||
return self.decode_inference(memory, states)
|
||||
|
|
|
@ -24,7 +24,8 @@ class Tacotron2(keras.models.Model):
|
|||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False):
|
||||
bidirectional_decoder=False,
|
||||
enable_tflite=False):
|
||||
super(Tacotron2, self).__init__()
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
|
@ -32,6 +33,7 @@ class Tacotron2(keras.models.Model):
|
|||
self.bidirectional_decoder = bidirectional_decoder
|
||||
self.num_speakers = num_speakers
|
||||
self.speaker_embed_dim = 256
|
||||
self.enable_tflite = enable_tflite
|
||||
|
||||
self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
|
||||
self.encoder = Encoder(512, name='encoder')
|
||||
|
@ -50,7 +52,8 @@ class Tacotron2(keras.models.Model):
|
|||
attn_K=attn_K,
|
||||
separate_stopnet=separate_stopnet,
|
||||
speaker_emb_dim=self.speaker_embed_dim,
|
||||
name='decoder')
|
||||
name='decoder',
|
||||
enable_tflite=enable_tflite)
|
||||
self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
|
||||
|
||||
@tf.function(experimental_relax_shapes=True)
|
||||
|
@ -82,6 +85,22 @@ class Tacotron2(keras.models.Model):
|
|||
print(output_frames.shape)
|
||||
return decoder_frames, output_frames, attentions, stop_tokens
|
||||
|
||||
@tf.function(
|
||||
experimental_relax_shapes=True,
|
||||
input_signature=[
|
||||
tf.TensorSpec([1, None], dtype=tf.int32),
|
||||
],)
|
||||
def inference_tflite(self, characters):
|
||||
B, T = shape_list(characters)
|
||||
embedding_vectors = self.embedding(characters, training=False)
|
||||
encoder_output = self.encoder(embedding_vectors, training=False)
|
||||
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
|
||||
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
|
||||
postnet_frames = self.postnet(decoder_frames, training=False)
|
||||
output_frames = decoder_frames + postnet_frames
|
||||
print(output_frames.shape)
|
||||
return decoder_frames, output_frames, attentions, stop_tokens
|
||||
|
||||
def build_inference(self, ):
|
||||
input_ids = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)
|
||||
self(input_ids)
|
||||
|
|
|
@ -76,7 +76,7 @@ def count_parameters(model, c):
|
|||
return model.count_params()
|
||||
|
||||
|
||||
def setup_model(num_chars, num_speakers, c):
|
||||
def setup_model(num_chars, num_speakers, c, enable_tflite=False):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower())
|
||||
MyModel = getattr(MyModel, c.model)
|
||||
|
@ -99,5 +99,6 @@ def setup_model(num_chars, num_speakers, c):
|
|||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
enable_tflite=enable_tflite)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue