From d41cb7fe47d3f7fc981e8da25a8e887f4c35384e Mon Sep 17 00:00:00 2001
From: erogol
Date: Thu, 9 Jul 2020 10:56:35 +0200
Subject: [PATCH] tf-lite updates for Tacotron2

---
 tf/layers/tacotron2.py    | 71 ++++++++++++++++++++++++++++++++++++++-
 tf/models/tacotron2.py    | 23 +++++++++++--
 tf/utils/generic_utils.py |  5 +--
 3 files changed, 94 insertions(+), 5 deletions(-)

diff --git a/tf/layers/tacotron2.py b/tf/layers/tacotron2.py
index e19be84b..0ba80107 100644
--- a/tf/layers/tacotron2.py
+++ b/tf/layers/tacotron2.py
@@ -58,12 +58,16 @@ class Decoder(keras.layers.Layer):
     #pylint: disable=unused-argument
     def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
                  prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
-                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs):
+                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs):
         super(Decoder, self).__init__(**kwargs)
         self.frame_dim = frame_dim
         self.r_init = tf.constant(r, dtype=tf.int32)
         self.r = tf.constant(r, dtype=tf.int32)
+        self.output_dim = r * self.frame_dim
         self.separate_stopnet = separate_stopnet
+        self.enable_tflite = enable_tflite
+
+        # layer constants
         self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
         self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
 
@@ -105,6 +109,7 @@ class Decoder(keras.layers.Layer):
 
     def set_r(self, new_r):
         self.r = tf.constant(new_r, dtype=tf.int32)
+        self.output_dim = self.frame_dim * new_r
 
     def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
         zero_frame = tf.zeros([batch_size, self.frame_dim])
@@ -183,6 +188,7 @@ class Decoder(keras.layers.Layer):
         outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
         attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
         stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
+
         # pre-computes
         self.attention.process_values(memory)
 
@@ -226,7 +232,70 @@
         outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
         return outputs, stop_tokens, attentions
 
+    def decode_inference_tflite(self, memory, states):
+        """Inference with TF-Lite compatibility. It assumes
+        batch_size is 1"""
+        # init states
+        # dynamic_shape is not supported in TFLite
+        outputs = tf.TensorArray(dtype=tf.float32,
+                                 size=self.max_decoder_steps,
+                                 element_shape=tf.TensorShape(
+                                     [self.output_dim]),
+                                 clear_after_read=False,
+                                 dynamic_size=False)
+        # stop_flags = tf.TensorArray(dtype=tf.bool,
+        #                             size=self.max_decoder_steps,
+        #                             element_shape=tf.TensorShape(
+        #                                 []),
+        #                             clear_after_read=False,
+        #                             dynamic_size=False)
+        attentions = ()
+        stop_tokens = ()
+
+        # pre-computes
+        self.attention.process_values(memory)
+
+        # iter vars
+        stop_flag = tf.constant(False, dtype=tf.bool)
+        step_count = tf.constant(0, dtype=tf.int32)
+
+        def _body(step, memory, states, outputs, stop_flag):
+            frame_next = states[0]
+            prenet_next = self.prenet(frame_next, training=False)
+            output, stop_token, states, _ = self.step(prenet_next,
+                                                      states,
+                                                      None,
+                                                      training=False)
+            stop_token = tf.math.sigmoid(stop_token)
+            stop_flag = tf.greater(stop_token, self.stop_thresh)
+            stop_flag = tf.reduce_all(stop_flag)
+            # stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
+
+            outputs = outputs.write(step, tf.reshape(output, [-1]))
+            return step + 1, memory, states, outputs, stop_flag
+
+        cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
+        step_count, memory, states, outputs, stop_flag = \
+            tf.while_loop(cond,
+                          _body,
+                          loop_vars=(step_count, memory, states, outputs,
+                                     stop_flag),
+                          parallel_iterations=32,
+                          swap_memory=True,
+                          maximum_iterations=self.max_decoder_steps)
+
+
+        outputs = outputs.stack()
+        outputs = tf.gather(outputs, tf.range(step_count))
+        outputs = tf.expand_dims(outputs, [0])
+        outputs = tf.transpose(outputs, [1, 0, 2])
+        outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
+        return outputs, stop_tokens, attentions
+
+
     def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
         if training:
             return self.decode(memory, states, frames, memory_seq_length)
+        if self.enable_tflite:
+            return self.decode_inference_tflite(memory, states)
         return self.decode_inference(memory, states)
diff --git a/tf/models/tacotron2.py b/tf/models/tacotron2.py
index b9a14e2b..03c64b8b 100644
--- a/tf/models/tacotron2.py
+++ b/tf/models/tacotron2.py
@@ -24,7 +24,8 @@ class Tacotron2(keras.models.Model):
                  forward_attn_mask=False,
                  location_attn=True,
                  separate_stopnet=True,
-                 bidirectional_decoder=False):
+                 bidirectional_decoder=False,
+                 enable_tflite=False):
         super(Tacotron2, self).__init__()
         self.r = r
         self.decoder_output_dim = decoder_output_dim
@@ -32,6 +33,7 @@ class Tacotron2(keras.models.Model):
         self.bidirectional_decoder = bidirectional_decoder
         self.num_speakers = num_speakers
         self.speaker_embed_dim = 256
+        self.enable_tflite = enable_tflite
 
         self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
         self.encoder = Encoder(512, name='encoder')
@@ -50,7 +52,8 @@ class Tacotron2(keras.models.Model):
                                attn_K=attn_K,
                                separate_stopnet=separate_stopnet,
                                speaker_emb_dim=self.speaker_embed_dim,
-                               name='decoder')
+                               name='decoder',
+                               enable_tflite=enable_tflite)
         self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
 
     @tf.function(experimental_relax_shapes=True)
@@ -82,6 +85,22 @@ class Tacotron2(keras.models.Model):
         print(output_frames.shape)
         return decoder_frames, output_frames, attentions, stop_tokens
 
+    @tf.function(
+        experimental_relax_shapes=True,
+        input_signature=[
+            tf.TensorSpec([1, None], dtype=tf.int32),
+        ],)
+    def inference_tflite(self, characters):
+        B, T = shape_list(characters)
+        embedding_vectors = self.embedding(characters, training=False)
+        encoder_output = self.encoder(embedding_vectors, training=False)
+        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
+        decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
+        postnet_frames = self.postnet(decoder_frames, training=False)
+        output_frames = decoder_frames + postnet_frames
+        print(output_frames.shape)
+        return decoder_frames, output_frames, attentions, stop_tokens
+
     def build_inference(self, ):
         input_ids = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)
         self(input_ids)
diff --git a/tf/utils/generic_utils.py b/tf/utils/generic_utils.py
index 3d385b10..1fea4cbb 100644
--- a/tf/utils/generic_utils.py
+++ b/tf/utils/generic_utils.py
@@ -76,7 +76,7 @@ def count_parameters(model, c):
     return model.count_params()
 
 
-def setup_model(num_chars, num_speakers, c):
+def setup_model(num_chars, num_speakers, c, enable_tflite=False):
     print(" > Using model: {}".format(c.model))
    MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower())
    MyModel = getattr(MyModel, c.model)
@@ -99,5 +99,6 @@ def setup_model(num_chars, num_speakers, c):
         location_attn=c.location_attn,
         attn_K=c.attention_heads,
         separate_stopnet=c.separate_stopnet,
-        bidirectional_decoder=c.bidirectional_decoder)
+        bidirectional_decoder=c.bidirectional_decoder,
+        enable_tflite=enable_tflite)
     return model
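
Usage note: a minimal conversion sketch for a model built with the new enable_tflite=True path, assuming a Tacotron2 `model` instance with trained weights already loaded and the stock TF 2.x tf.lite converter; the output file name is illustrative only.

    import tensorflow as tf

    # Sketch under assumptions: `model` is a Tacotron2 created with
    # enable_tflite=True and its weights are already loaded.
    concrete_fn = model.inference_tflite.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
    # The TFLite-friendly decoder still uses tf.while_loop / TensorArray, so
    # keep Select TF ops as a fallback for ops without native TFLite kernels.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                           tf.lite.OpsSet.SELECT_TF_OPS]
    converter.experimental_new_converter = True
    tflite_model = converter.convert()
    with open('tacotron2.tflite', 'wb') as f:
        f.write(tflite_model)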