lint updates

erogol 2020-07-09 14:22:45 +02:00
parent 26e0b61492
commit 67735cbdac
8 changed files with 43 additions and 35 deletions

View File

@@ -1,10 +1,9 @@
# Convert Tensorflow Tacotron2 model to TF-Lite binary
import tensorflow as tf
import argparse
from TTS.utils.io import load_config
from TTS.utils.text.symbols import symbols, phonemes, make_symbols
from TTS.utils.text.symbols import symbols, phonemes
from TTS.tf.utils.generic_utils import setup_model
from TTS.tf.utils.io import load_checkpoint
from TTS.tf.utils.tflite import convert_tacotron2_to_tflite
@@ -35,10 +34,6 @@ model = load_checkpoint(model, args.tf_model)
model.decoder.set_max_decoder_steps(1000)
# create tflite model
tflite_model = convert_tacotron2_to_tflite(model)
# save tflite binary
with open(args.output_path, 'wb') as f:
f.write(tflite_model)
tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)
print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.')
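The converted binary can be sanity-checked by loading it back with TensorFlow's bundled interpreter. A minimal sketch, assuming the script above wrote its output to tacotron2.tflite:

import tensorflow as tf

# load the converted flatbuffer and allocate its tensors; a failure here
# usually means the conversion above did not complete cleanly
interpreter = tf.lite.Interpreter(model_path='tacotron2.tflite')
interpreter.allocate_tensors()
print(interpreter.get_input_details()[0]['shape'])  # expected input layout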

View File

@@ -115,9 +115,10 @@ class Attention(keras.layers.Layer):
attention_old = tf.zeros([batch_size, value_length])
states = [attention_cum, attention_old]
if self.use_forward_attn:
alpha = tf.concat(
[tf.ones([batch_size, 1]),
tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], axis=1)
alpha = tf.concat([
tf.ones([batch_size, 1]),
tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
], 1)
states.append(alpha)
return tuple(states)
@@ -155,9 +156,9 @@ class Attention(keras.layers.Layer):
score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
return score
def apply_forward_attention(self, alignment, alpha):
def apply_forward_attention(self, alignment, alpha): #pylint: disable=no-self-use
# forward attention
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)))
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
# compute transition potentials
new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
# renormalize attention weights
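Together, init_states and apply_forward_attention implement the forward-attention recursion: alpha starts with nearly all probability mass on the first encoder step, and every decoder step mixes "stay" and "advance one step" transitions before renormalizing. A standalone sketch of one update, with made-up batch_size and value_length (the model derives them from its inputs):

import tensorflow as tf

batch_size, value_length = 2, 6
# a dummy alignment, standing in for the softmax over attention scores
alignment = tf.nn.softmax(tf.random.normal([batch_size, value_length]), axis=1)
# initial alpha as in init_states(): all mass on the first encoder step
alpha = tf.concat([
    tf.ones([batch_size, 1]),
    tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
], 1)

# shift alpha one position right; tf.pad inserts a zero column at the front
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
# equal-weight (0.5) mix of staying and advancing, gated by the alignment
new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
# renormalize so each row is again a probability distribution
new_alpha /= tf.reduce_sum(new_alpha, axis=1, keepdims=True)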

View File

@@ -1,4 +1,3 @@
import tensorflow as tf
from tensorflow import keras
from TTS.tf.utils.tf_utils import shape_list
@@ -83,8 +82,8 @@ class Decoder(keras.layers.Layer):
prenet_dropout,
[self.prenet_dim, self.prenet_dim],
bias=False,
name=f'prenet')
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'attention_rnn', )
name='prenet')
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', )
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
# TODO: implement other attn options
@@ -98,10 +97,10 @@
use_trans_agent=use_trans_agent,
use_forward_attn_mask=use_forward_attn_mask,
name='attention')
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'decoder_rnn')
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn')
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'linear_projection/linear_layer')
self.stopnet = keras.layers.Dense(1, name=f'stopnet/linear_layer')
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer')
self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer')
def set_max_decoder_steps(self, new_max_steps):
@@ -263,9 +262,9 @@ class Decoder(keras.layers.Layer):
frame_next = states[0]
prenet_next = self.prenet(frame_next, training=False)
output, stop_token, states, _ = self.step(prenet_next,
states,
None,
training=False)
states,
None,
training=False)
stop_token = tf.math.sigmoid(stop_token)
stop_flag = tf.greater(stop_token, self.stop_thresh)
stop_flag = tf.reduce_all(stop_flag)
@@ -286,8 +285,8 @@ class Decoder(keras.layers.Layer):
outputs = outputs.stack()
outputs = tf.gather(outputs, tf.range(step_count))
outputs = tf.expand_dims(outputs, [0])
outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter
outputs = tf.expand_dims(outputs, axis=[0])
outputs = tf.transpose(outputs, [1, 0, 2])
outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
return outputs, stop_tokens, attentions
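The inference loop above exits through the stop-token gate shown in the previous hunk. A self-contained sketch of just that gate (the 0.5 threshold is illustrative; the model reads it from self.stop_thresh):

import tensorflow as tf

stop_thresh = 0.5                         # illustrative threshold value
stop_token = tf.constant([[2.0]])         # a raw stopnet logit for one frame
stop_token = tf.math.sigmoid(stop_token)  # squash to (0, 1)
stop_flag = tf.greater(stop_token, stop_thresh)
stop_flag = tf.reduce_all(stop_flag)      # every batch item must agree
print(bool(stop_flag))                    # True: sigmoid(2.0) ~ 0.88 > 0.5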

View File

@@ -102,6 +102,6 @@ class Tacotron2(keras.models.Model):
return decoder_frames, output_frames, attentions, stop_tokens
def build_inference(self, ):
input_ids = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)
input_ids = tf.random.uniform([1, 4], 0, 10, tf.int32)
self(input_ids)
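build_inference exists because subclassed Keras models create their variables lazily: weights only materialize on the first call, and load_checkpoint can only assign into variables that already exist. A toy model (not the project's class) showing that behavior:

import tensorflow as tf

class Toy(tf.keras.models.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, x):  # pylint: disable=arguments-differ
        return self.dense(x)

model = Toy()
print(len(model.weights))         # 0: nothing is built yet
model(tf.random.uniform([1, 4]))  # dummy pass, mirroring build_inference
print(len(model.weights))         # 2: kernel and bias now exist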

View File

@@ -1,16 +1,26 @@
import tensorflow as tf
def convert_tacotron2_to_tflite(model):
tacotron2_concrete_function = model.inference_tflite.get_concrete_function()
def convert_tacotron2_to_tflite(model,
output_path=None,
experimental_converter=True):
"""Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is
provided, else return TFLite model."""
concrete_function = model.inference_tflite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[tacotron2_concrete_function]
)
converter.experimental_new_converter = True
[concrete_function])
converter.experimental_new_converter = experimental_converter
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS]
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()
if output_path is not None:
# save model binary if output_path is provided
with open(output_path, 'wb') as f:
f.write(tflite_model)
return None
return tflite_model
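The reworked signature supports both call styles; a short usage sketch, assuming a built Tacotron2 model as in the conversion script above:

# in-memory conversion: the serialized flatbuffer comes back as bytes
tflite_bytes = convert_tacotron2_to_tflite(model)

# file-backed conversion: the binary is written to disk and None is returned
convert_tacotron2_to_tflite(model, output_path='tacotron2.tflite')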

View File

@@ -84,7 +84,7 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_me
model.resize_tensor_input(input_details[0]['index'], inputs.shape)
model.allocate_tensors()
detail = input_details[0]
input_shape = detail['shape']
# input_shape = detail['shape']
model.set_tensor(detail['index'], inputs)
# run the model
model.invoke()
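After invoke() returns, outputs are read back through the same details/index mechanism used to feed the input; a sketch of the usual follow-up (model here is the tf.lite.Interpreter instance):

# fetch the first output tensor; its index comes from the interpreter
# itself, so no particular tensor name is assumed
output_details = model.get_output_details()
output = model.get_tensor(output_details[0]['index'])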

View File

@@ -8,7 +8,9 @@ import tensorflow as tf
from TTS.vocoder.tf.layers.melgan import ResidualStack, ReflectionPad1d
class MelganGenerator(tf.keras.models.Model): # pylint: disable=too-many-ancestors
#pylint: disable=too-many-ancestors
#pylint: disable=abstract-method
class MelganGenerator(tf.keras.models.Model):
""" Melgan Generator TF implementation dedicated for inference with no
weight norm """
def __init__(self,

View File

@@ -3,8 +3,9 @@ import tensorflow as tf
from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
from TTS.vocoder.tf.layers.pqmf import PQMF
class MultibandMelganGenerator(MelganGenerator): # pylint: disable=too-many-ancestors
#pylint: disable=too-many-ancestors
#pylint: disable=abstract-method
class MultibandMelganGenerator(MelganGenerator):
def __init__(self,
in_channels=80,
out_channels=4,