diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 46f919dc..9064811a 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -248,15 +248,11 @@ def synthesis( style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) text_inputs = text_inputs.unsqueeze(0) - elif backend == "tf": + elif backend in ["tf", "tflite"]: # TODO: handle speaker id for tf model style_mel = numpy_to_tf(style_mel, tf.float32) text_inputs = numpy_to_tf(text_inputs, tf.int32) text_inputs = tf.expand_dims(text_inputs, 0) - elif backend == "tflite": - style_mel = numpy_to_tf(style_mel, tf.float32) - text_inputs = numpy_to_tf(text_inputs, tf.int32) - text_inputs = tf.expand_dims(text_inputs, 0) # synthesize voice if backend == "torch": outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)