diff --git a/tf/convert_tacotron2_tflite.py b/tf/convert_tacotron2_tflite.py
new file mode 100644
index 00000000..28039376
--- /dev/null
+++ b/tf/convert_tacotron2_tflite.py
@@ -0,0 +1,44 @@
+# Convert TensorFlow Tacotron2 model to a TF-Lite binary
+
+import argparse
+import tensorflow as tf
+
+from TTS.utils.io import load_config
+from TTS.utils.text.symbols import symbols, phonemes, make_symbols
+from TTS.tf.utils.generic_utils import setup_model
+from TTS.tf.utils.io import load_checkpoint
+from TTS.tf.utils.tflite import convert_tacotron2_to_tflite
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--tf_model',
+                    type=str,
+                    help='Path to the trained TF Tacotron2 checkpoint to be converted to TF-Lite.')
+parser.add_argument('--config_path',
+                    type=str,
+                    help='Path to the config file of the model.')
+parser.add_argument('--output_path',
+                    type=str,
+                    help='Path to the output TF-Lite binary.')
+args = parser.parse_args()
+
+# set constants
+CONFIG = load_config(args.config_path)
+
+# load the model
+c = CONFIG
+num_speakers = 0
+num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
+model.build_inference()
+model = load_checkpoint(model, args.tf_model)
+model.decoder.set_max_decoder_steps(1000)
+
+# create the tflite model
+tflite_model = convert_tacotron2_to_tflite(model)
+
+# save the tflite binary
+with open(args.output_path, 'wb') as f:
+    f.write(tflite_model)
+
+print(f'TF-Lite model size is {len(tflite_model) / (1024.0 * 1024.0):.2f} MB.')
diff --git a/tf/layers/tacotron2.py b/tf/layers/tacotron2.py
index e19be84b..0ba80107 100644
--- a/tf/layers/tacotron2.py
+++ b/tf/layers/tacotron2.py
@@ -58,12 +58,16 @@ class Decoder(keras.layers.Layer):
     #pylint: disable=unused-argument
     def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm,
                  prenet_type, prenet_dropout, use_forward_attn,
                  use_trans_agent, use_forward_attn_mask,
-                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs):
+                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs):
         super(Decoder, self).__init__(**kwargs)
         self.frame_dim = frame_dim
         self.r_init = tf.constant(r, dtype=tf.int32)
         self.r = tf.constant(r, dtype=tf.int32)
+        self.output_dim = r * self.frame_dim
         self.separate_stopnet = separate_stopnet
+        self.enable_tflite = enable_tflite
+
+        # layer constants
         self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
         self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
@@ -105,6 +109,7 @@ class Decoder(keras.layers.Layer):
 
     def set_r(self, new_r):
         self.r = tf.constant(new_r, dtype=tf.int32)
+        self.output_dim = self.frame_dim * new_r
 
     def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
         zero_frame = tf.zeros([batch_size, self.frame_dim])
@@ -183,6 +188,7 @@ class Decoder(keras.layers.Layer):
         outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
         attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
         stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
+
         # pre-computes
         self.attention.process_values(memory)
 
@@ -226,7 +232,70 @@ class Decoder(keras.layers.Layer):
         outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
         return outputs, stop_tokens, attentions
 
+    def decode_inference_tflite(self, memory, states):
+        """Inference with TF-Lite compatibility. It assumes
+        batch_size is 1."""
+        # init states
+        # dynamic_shape is not supported in TFLite
+        outputs = tf.TensorArray(dtype=tf.float32,
+                                 size=self.max_decoder_steps,
+                                 element_shape=tf.TensorShape(
+                                     [self.output_dim]),
+                                 clear_after_read=False,
+                                 dynamic_size=False)
+        # stop_flags = tf.TensorArray(dtype=tf.bool,
+        #                             size=self.max_decoder_steps,
+        #                             element_shape=tf.TensorShape(
+        #                                 []),
+        #                             clear_after_read=False,
+        #                             dynamic_size=False)
+        attentions = ()
+        stop_tokens = ()
+
+        # pre-computes
+        self.attention.process_values(memory)
+
+        # iter vars
+        stop_flag = tf.constant(False, dtype=tf.bool)
+        step_count = tf.constant(0, dtype=tf.int32)
+
+        def _body(step, memory, states, outputs, stop_flag):
+            frame_next = states[0]
+            prenet_next = self.prenet(frame_next, training=False)
+            output, stop_token, states, _ = self.step(prenet_next,
+                                                      states,
+                                                      None,
+                                                      training=False)
+            stop_token = tf.math.sigmoid(stop_token)
+            stop_flag = tf.greater(stop_token, self.stop_thresh)
+            stop_flag = tf.reduce_all(stop_flag)
+            # stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
+
+            outputs = outputs.write(step, tf.reshape(output, [-1]))
+            return step + 1, memory, states, outputs, stop_flag
+
+        cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
+        step_count, memory, states, outputs, stop_flag = \
+            tf.while_loop(cond,
+                          _body,
+                          loop_vars=(step_count, memory, states, outputs,
+                                     stop_flag),
+                          parallel_iterations=32,
+                          swap_memory=True,
+                          maximum_iterations=self.max_decoder_steps)
+
+        outputs = outputs.stack()
+        outputs = tf.gather(outputs, tf.range(step_count))
+        outputs = tf.expand_dims(outputs, [0])
+        outputs = tf.transpose(outputs, [1, 0, 2])
+        outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
+        return outputs, stop_tokens, attentions
+
     def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
         if training:
             return self.decode(memory, states, frames, memory_seq_length)
+        if self.enable_tflite:
+            return self.decode_inference_tflite(memory, states)
         return self.decode_inference(memory, states)
diff --git a/tf/models/tacotron2.py b/tf/models/tacotron2.py
index b9a14e2b..03c64b8b 100644
--- a/tf/models/tacotron2.py
+++ b/tf/models/tacotron2.py
@@ -24,7 +24,8 @@ class Tacotron2(keras.models.Model):
                  forward_attn_mask=False,
                  location_attn=True,
                  separate_stopnet=True,
-                 bidirectional_decoder=False):
+                 bidirectional_decoder=False,
+                 enable_tflite=False):
         super(Tacotron2, self).__init__()
         self.r = r
         self.decoder_output_dim = decoder_output_dim
@@ -32,6 +33,7 @@ class Tacotron2(keras.models.Model):
         self.bidirectional_decoder = bidirectional_decoder
         self.num_speakers = num_speakers
         self.speaker_embed_dim = 256
+        self.enable_tflite = enable_tflite
 
         self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
         self.encoder = Encoder(512, name='encoder')
@@ -50,7 +52,8 @@ class Tacotron2(keras.models.Model):
                                attn_K=attn_K,
                                separate_stopnet=separate_stopnet,
                                speaker_emb_dim=self.speaker_embed_dim,
-                               name='decoder')
+                               name='decoder',
+                               enable_tflite=enable_tflite)
         self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
 
     @tf.function(experimental_relax_shapes=True)
@@ -82,6 +85,22 @@ class Tacotron2(keras.models.Model):
         print(output_frames.shape)
         return decoder_frames, output_frames, attentions, stop_tokens
 
+    @tf.function(
+        experimental_relax_shapes=True,
+        input_signature=[
+            tf.TensorSpec([1, None], dtype=tf.int32),
+        ],)
+    def inference_tflite(self, characters):
+        B, T = shape_list(characters)
+        embedding_vectors = self.embedding(characters, training=False)
+        encoder_output = self.encoder(embedding_vectors, training=False)
+        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
+        decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
+        postnet_frames = self.postnet(decoder_frames, training=False)
+        output_frames = decoder_frames + postnet_frames
+        print(output_frames.shape)
+        return decoder_frames, output_frames, attentions, stop_tokens
+
     def build_inference(self, ):
         input_ids = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)
         self(input_ids)
diff --git a/tf/utils/generic_utils.py b/tf/utils/generic_utils.py
index 3d385b10..1fea4cbb 100644
--- a/tf/utils/generic_utils.py
+++ b/tf/utils/generic_utils.py
@@ -76,7 +76,7 @@ def count_parameters(model, c):
     return model.count_params()
 
 
-def setup_model(num_chars, num_speakers, c):
+def setup_model(num_chars, num_speakers, c, enable_tflite=False):
     print(" > Using model: {}".format(c.model))
     MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower())
     MyModel = getattr(MyModel, c.model)
@@ -99,5 +99,6 @@ def setup_model(num_chars, num_speakers, c):
                         location_attn=c.location_attn,
                         attn_K=c.attention_heads,
                         separate_stopnet=c.separate_stopnet,
-                        bidirectional_decoder=c.bidirectional_decoder)
+                        bidirectional_decoder=c.bidirectional_decoder,
+                        enable_tflite=enable_tflite)
     return model
diff --git a/tf/utils/io.py b/tf/utils/io.py
new file mode 100644
index 00000000..78a56de4
--- /dev/null
+++ b/tf/utils/io.py
@@ -0,0 +1,42 @@
+import pickle
+import datetime
+import tensorflow as tf
+
+
+def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
+    state = {
+        'model': model.weights,
+        'optimizer': optimizer,
+        'step': current_step,
+        'epoch': epoch,
+        'date': datetime.date.today().strftime("%B %d, %Y"),
+        'r': r
+    }
+    state.update(kwargs)
+    pickle.dump(state, open(output_path, 'wb'))
+
+
+def load_checkpoint(model, checkpoint_path):
+    checkpoint = pickle.load(open(checkpoint_path, 'rb'))
+    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
+    tf_vars = model.weights
+    for tf_var in tf_vars:
+        layer_name = tf_var.name
+        try:
+            chkp_var_value = chkp_var_dict[layer_name]
+        except KeyError:
+            class_name = list(chkp_var_dict.keys())[0].split("/")[0]
+            layer_name = f"{class_name}/{layer_name}"
+            chkp_var_value = chkp_var_dict[layer_name]
+
+        tf.keras.backend.set_value(tf_var, chkp_var_value)
+    if 'r' in checkpoint.keys():
+        model.decoder.set_r(checkpoint['r'])
+    return model
+
+
+def load_tflite_model(tflite_path):
+    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
+    tflite_model.allocate_tensors()
+    return tflite_model
+
diff --git a/tf/utils/tflite.py b/tf/utils/tflite.py
new file mode 100644
index 00000000..a46c1dce
--- /dev/null
+++ b/tf/utils/tflite.py
@@ -0,0 +1,20 @@
+import tensorflow as tf
+
+
+def convert_tacotron2_to_tflite(model):
+    tacotron2_concrete_function = model.inference_tflite.get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions(
+        [tacotron2_concrete_function]
+    )
+    converter.experimental_new_converter = True
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
+                                           tf.lite.OpsSet.SELECT_TF_OPS]
+    tflite_model = converter.convert()
+    return tflite_model
+
+
+def load_tflite_model(tflite_path):
+    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
+    tflite_model.allocate_tensors()
+    return tflite_model
\ No newline at end of file
diff --git a/utils/synthesis.py b/utils/synthesis.py
index 03d7072e..056a7b46 100644
--- a/utils/synthesis.py
+++ b/utils/synthesis.py
@@ -70,6 +70,31 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
     return decoder_output, postnet_output, alignments, stop_tokens
 
 
+def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+    if CONFIG.use_gst and style_mel is not None:
+        raise NotImplementedError(' [!] GST inference not implemented for TfLite')
+    if truncated:
+        raise NotImplementedError(' [!] Truncated inference not implemented for TfLite')
+    if speaker_id is not None:
+        raise NotImplementedError(' [!] Multi-Speaker not implemented for TfLite')
+    # get input and output details
+    input_details = model.get_input_details()
+    output_details = model.get_output_details()
+    # reshape input tensor for the new input shape
+    model.resize_tensor_input(input_details[0]['index'], inputs.shape)
+    model.allocate_tensors()
+    detail = input_details[0]
+    input_shape = detail['shape']
+    model.set_tensor(detail['index'], inputs)
+    # run the model
+    model.invoke()
+    # collect outputs
+    decoder_output = model.get_tensor(output_details[0]['index'])
+    postnet_output = model.get_tensor(output_details[1]['index'])
+    # tflite model only returns feature frames
+    return decoder_output, postnet_output, None, None
+
+
 def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
     decoder_output = decoder_output[0].data.cpu().numpy()
@@ -86,6 +111,12 @@ def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
     return postnet_output, decoder_output, alignment, stop_tokens
 
 
+def parse_outputs_tflite(postnet_output, decoder_output):
+    postnet_output = postnet_output[0]
+    decoder_output = decoder_output[0]
+    return postnet_output, decoder_output
+
+
 def trim_silence(wav, ap):
     return wav[:ap.find_endpoint(wav)]
 
@@ -164,22 +195,31 @@ def synthesis(model,
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
         inputs = inputs.unsqueeze(0)
-    else:
+    elif backend == 'tf':
         # TODO: handle speaker id for tf model
         style_mel = numpy_to_tf(style_mel, tf.float32)
         inputs = numpy_to_tf(inputs, tf.int32)
         inputs = tf.expand_dims(inputs, 0)
+    elif backend == 'tflite':
+        style_mel = numpy_to_tf(style_mel, tf.float32)
+        inputs = numpy_to_tf(inputs, tf.int32)
+        inputs = tf.expand_dims(inputs, 0)
     # synthesize voice
     if backend == 'torch':
         decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
             model, inputs, CONFIG, truncated, speaker_id, style_mel)
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
             postnet_output, decoder_output, alignments, stop_tokens)
-    else:
+    elif backend == 'tf':
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
             model, inputs, CONFIG, truncated, speaker_id, style_mel)
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
             postnet_output, decoder_output, alignments, stop_tokens)
+    elif backend == 'tflite':
+        decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
+            model, inputs, CONFIG, truncated, speaker_id, style_mel)
+        postnet_output, decoder_output = parse_outputs_tflite(
+            postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
     wav = None
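
For reference, here is a minimal end-to-end sketch of how the pieces in this diff fit together. It is not part of the patch: the file paths and dummy input ids are placeholders, and the output tensor ordering (decoder output at index 0, postnet output at index 1) follows the assumption already made in `run_model_tflite` above.

```python
# Sketch only: exercise the conversion + TF-Lite inference path added in this
# diff. Paths and the dummy input sequence below are placeholders.
#
# Step 1 -- convert a trained TF checkpoint (run from the repo root):
#   python tf/convert_tacotron2_tflite.py --tf_model tts_model.pkl \
#       --config_path config.json --output_path tacotron2.tflite
#
# Step 2 -- run the resulting binary with the TF-Lite interpreter:
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='tacotron2.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# a dummy [1, T] sequence of character/phoneme ids; int32 to match the
# tf.TensorSpec([1, None], dtype=tf.int32) input signature of inference_tflite
inputs = np.random.randint(0, 60, size=(1, 57)).astype(np.int32)

# the graph is exported with a fixed input shape, so resize the input tensor
# to the actual sequence length before allocating and setting it
interpreter.resize_tensor_input(input_details[0]['index'], inputs.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], inputs)
interpreter.invoke()

# assumed ordering, mirroring run_model_tflite(): [0] decoder, [1] postnet
decoder_output = interpreter.get_tensor(output_details[0]['index'])
postnet_output = interpreter.get_tensor(output_details[1]['index'])
print(decoder_output.shape, postnet_output.shape)
```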