diff --git a/README.md b/README.md
index ce496707..31bfe73c 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,14 @@

-
+
+
+
+
-This project is a part of [Mozilla Common Voice](https://voice.mozilla.org/en). TTS aims a deep learning based Text2Speech engine, low in cost and high in quality.
+This project is a part of [Mozilla Common Voice](https://voice.mozilla.org/en).
+
+Mozilla TTS aims a deep learning based Text2Speech engine, low in cost and high in quality.
You can check some of synthesized voice samples from [here](https://erogol.github.io/ddc-samples/).
@@ -38,25 +43,26 @@ Vocoders:
You can also help us implement more models. Some TTS related work can be found [here](https://github.com/erogol/TTS-papers).
## Features
-- High performance Deep Learning models for Text2Speech related tasks.
- - Text2Speech models (Tacotron, Tacotron2).
+- High performance Deep Learning models for Text2Speech tasks.
+ - Text2Spec models (Tacotron, Tacotron2).
- Speaker Encoder to compute speaker embeddings efficiently.
- - Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS)
-- Support for multi-speaker TTS training.
-- Support for Multi-GPUs training.
-- Ability to convert Torch models to Tensorflow 2.0 for inference.
-- Released pre-trained models.
+ - Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS, ParallelWaveGAN)
- Fast and efficient model training.
- Detailed training logs on console and Tensorboard.
+- Support for multi-speaker TTS.
+- Efficient Multi-GPUs training.
+- Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference.
+- Released models in PyTorch, Tensorflow and TFLite.
- Tools to curate Text2Speech datasets under```dataset_analysis```.
- Demo server for model testing.
- Notebooks for extensive model benchmarking.
- Modular (but not too much) code base enabling easy testing for new ideas.
-## Requirements and Installation
+## Main Requirements and Installation
Highly recommended to use [miniconda](https://conda.io/miniconda.html) for easier installation.
* python>=3.6
- * pytorch>=0.4.1
+ * pytorch>=1.4.1
+ * tensorflow>=2.2
* librosa
* tensorboard
* tensorboardX
@@ -107,26 +113,14 @@ Audio examples: [soundcloud](https://soundcloud.com/user-565970875/pocket-articl
-## Runtime
-The most time-consuming part is the vocoder algorithm (Griffin-Lim) which runs on CPU. By setting its number of iterations lower, you might have faster execution with a small loss of quality. Some of the experimental values are below.
-
-Sentence: "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."
-
-Audio length is approximately 6 secs.
-
-| Time (secs) | System | # GL iters | Model
-| ---- |:-------|:-----------| ---- |
-|2.00|GTX1080Ti|30|Tacotron|
-|3.01|GTX1080Ti|60|Tacotron|
-|3.57|CPU|60|Tacotron|
-|5.27|GTX1080Ti|60|Tacotron2|
-|6.50|CPU|60|Tacotron2|
-
+## [Mozilla TTS Tutorials and Notebooks](https://github.com/mozilla/TTS/wiki/TTS-Notebooks-and-Tutorials)
## Datasets and Data-Loading
-TTS provides a generic dataloader easy to use for new datasets. You need to write an preprocessor function to integrate your own dataset.Check ```datasets/preprocess.py``` to see some examples. After the function, you need to set ```dataset``` field in ```config.json```. Do not forget other data related fields too.
+TTS provides a generic dataloader easy to use for your custom dataset.
+You just need to write a simple function to format the dataset. Check ```datasets/preprocess.py``` to see some examples.
+After that, you need to set ```dataset``` fields in ```config.json```.
-Some of the open-sourced datasets that we successfully applied TTS, are linked below.
+Some of the public datasets that we successfully applied TTS:
- [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
- [Nancy](http://www.cstr.ed.ac.uk/projects/blizzard/2011/lessac_blizzard2011/)
@@ -164,8 +158,6 @@ In case of any error or intercepted execution, if there is no checkpoint yet und
You can also enjoy Tensorboard, if you point Tensorboard argument```--logdir``` to the experiment folder.
-## [Testing and Examples](https://github.com/mozilla/TTS/wiki/Examples-using-TTS)
-
## Contribution guidelines
This repository is governed by Mozilla's code of conduct and etiquette guidelines. For more details, please read the [Mozilla Community Participation Guidelines.](https://www.mozilla.org/about/governance/policies/participation/)
diff --git a/requirements.txt b/requirements.txt
index 959fe2d7..ec7a1092 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,7 +15,10 @@ tqdm
inflect
pysbd
bokeh==1.4.0
+pysbd
soundfile
nose==1.3.7
cardboardlint==1.3.0
pylint==2.5.3
+fuzzywuzzy
+gdown
diff --git a/server/synthesizer.py b/server/synthesizer.py
index 0c402609..0f743d87 100644
--- a/server/synthesizer.py
+++ b/server/synthesizer.py
@@ -1,5 +1,4 @@
import io
-import re
import sys
import time
diff --git a/setup.py b/setup.py
index b139dc20..3f02dd09 100644
--- a/setup.py
+++ b/setup.py
@@ -93,11 +93,14 @@ requirements = {
"inflect",
"pysbd",
"bokeh==1.4.0",
+ "pysbd",
"soundfile",
"phonemizer>=2.2.0",
"nose==1.3.7",
"cardboardlint==1.3.0",
"pylint==2.5.3",
+ 'fuzzywuzzy',
+ 'gdown'
],
'pip_install':[
'tensorflow>=2.2.0',
diff --git a/tf/convert_tacotron2_tflite.py b/tf/convert_tacotron2_tflite.py
index e06cac2b..fc46cc79 100644
--- a/tf/convert_tacotron2_tflite.py
+++ b/tf/convert_tacotron2_tflite.py
@@ -34,6 +34,4 @@ model = load_checkpoint(model, args.tf_model)
model.decoder.set_max_decoder_steps(1000)
# create tflite model
-tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)
-
-print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.')
+tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)
\ No newline at end of file
diff --git a/tf/utils/tflite.py b/tf/utils/tflite.py
index 6c37f170..5e684b30 100644
--- a/tf/utils/tflite.py
+++ b/tf/utils/tflite.py
@@ -16,6 +16,7 @@ def convert_tacotron2_to_tflite(model,
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()
+ print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.')
if output_path is not None:
# same model binary if outputpath is provided
with open(output_path, 'wb') as f:
diff --git a/vocoder/tf/convert_melgan_tflite.py b/vocoder/tf/convert_melgan_tflite.py
new file mode 100644
index 00000000..9a652b57
--- /dev/null
+++ b/vocoder/tf/convert_melgan_tflite.py
@@ -0,0 +1,33 @@
+# Convert Tensorflow Tacotron2 model to TF-Lite binary
+
+import argparse
+
+from TTS.utils.io import load_config
+from TTS.vocoder.tf.utils.generic_utils import setup_generator
+from TTS.vocoder.tf.utils.io import load_checkpoint
+from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--tf_model',
+ type=str,
+ help='Path to target torch model to be converted to TF.')
+parser.add_argument('--config_path',
+ type=str,
+ help='Path to config file of torch model.')
+parser.add_argument('--output_path',
+ type=str,
+ help='path to tflite output binary.')
+args = parser.parse_args()
+
+# Set constants
+CONFIG = load_config(args.config_path)
+
+# load the model
+model = setup_generator(CONFIG)
+model.build_inference()
+model = load_checkpoint(model, args.tf_model)
+
+# create tflite model
+tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path)
+
diff --git a/vocoder/tf/layers/pqmf.py b/vocoder/tf/layers/pqmf.py
index 6c47dfc4..c018971f 100644
--- a/vocoder/tf/layers/pqmf.py
+++ b/vocoder/tf/layers/pqmf.py
@@ -51,7 +51,7 @@ class PQMF(tf.keras.layers.Layer):
def synthesis(self, x):
"""
- x : B x 1 x T
+ x : B x D x T
"""
x = tf.transpose(x, perm=[0, 2, 1])
x = tf.nn.conv1d_transpose(
diff --git a/vocoder/tf/models/melgan_generator.py b/vocoder/tf/models/melgan_generator.py
index bf67f3d2..168fd29e 100644
--- a/vocoder/tf/models/melgan_generator.py
+++ b/vocoder/tf/models/melgan_generator.py
@@ -108,4 +108,21 @@ class MelganGenerator(tf.keras.models.Model):
def build_inference(self):
x = tf.random.uniform((1, self.in_channels, 4), dtype=tf.float32)
- self(x, training=False)
\ No newline at end of file
+ self(x, training=False)
+
+ @tf.function(
+ experimental_relax_shapes=True,
+ input_signature=[
+ tf.TensorSpec([1, None, None], dtype=tf.float32),
+ ],)
+ def inference_tflite(self, c):
+ c = tf.transpose(c, perm=[0, 2, 1])
+ c = tf.expand_dims(c, 2)
+ # FIXME: TF had no replicate padding as in Torch
+ # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
+ o = c
+ for layer in self.model_layers:
+ o = layer(o)
+ # o = self.model_layers(c)
+ o = tf.transpose(o, perm=[0, 3, 2, 1])
+ return o[:, :, 0, :]
\ No newline at end of file
diff --git a/vocoder/tf/models/multiband_melgan_generator.py b/vocoder/tf/models/multiband_melgan_generator.py
index c63ed06a..bdd333ed 100644
--- a/vocoder/tf/models/multiband_melgan_generator.py
+++ b/vocoder/tf/models/multiband_melgan_generator.py
@@ -30,11 +30,6 @@ class MultibandMelganGenerator(MelganGenerator):
def pqmf_synthesis(self, x):
return self.pqmf_layer.synthesis(x)
- # def call(self, c, training=False):
- # if training:
- # raise NotImplementedError()
- # return self.inference(c)
-
def inference(self, c):
c = tf.transpose(c, perm=[0, 2, 1])
c = tf.expand_dims(c, 2)
@@ -46,3 +41,20 @@ class MultibandMelganGenerator(MelganGenerator):
o = tf.transpose(o, perm=[0, 3, 2, 1])
o = self.pqmf_layer.synthesis(o[:, :, 0, :])
return o
+
+ @tf.function(
+ experimental_relax_shapes=True,
+ input_signature=[
+ tf.TensorSpec([1, 80, None], dtype=tf.float32),
+ ],)
+ def inference_tflite(self, c):
+ c = tf.transpose(c, perm=[0, 2, 1])
+ c = tf.expand_dims(c, 2)
+ # FIXME: TF had no replicate padding as in Torch
+ # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
+ o = c
+ for layer in self.model_layers:
+ o = layer(o)
+ o = tf.transpose(o, perm=[0, 3, 2, 1])
+ o = self.pqmf_layer.synthesis(o[:, :, 0, :])
+ return o
diff --git a/vocoder/tf/utils/tflite.py b/vocoder/tf/utils/tflite.py
new file mode 100644
index 00000000..d0637596
--- /dev/null
+++ b/vocoder/tf/utils/tflite.py
@@ -0,0 +1,31 @@
+import tensorflow as tf
+
+
+def convert_melgan_to_tflite(model,
+ output_path=None,
+ experimental_converter=True):
+ """Convert Tensorflow MelGAN model to TFLite. Save a binary file if output_path is
+ provided, else return TFLite model."""
+
+ concrete_function = model.inference_tflite.get_concrete_function()
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function])
+ converter.experimental_new_converter = experimental_converter
+ converter.optimizations = []
+ converter.target_spec.supported_ops = [
+ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
+ ]
+ tflite_model = converter.convert()
+ print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.')
+ if output_path is not None:
+ # same model binary if outputpath is provided
+ with open(output_path, 'wb') as f:
+ f.write(tflite_model)
+ return None
+ return tflite_model
+
+
+def load_tflite_model(tflite_path):
+ tflite_model = tf.lite.Interpreter(model_path=tflite_path)
+ tflite_model.allocate_tensors()
+ return tflite_model
\ No newline at end of file