diff --git a/README.md b/README.md
index 9fb5deb2..5c631140 100644
--- a/README.md
+++ b/README.md
@@ -10,11 +10,11 @@ TTS comes with [pretrained models](https://github.com/mozilla/TTS/wiki/Released-
 [![License]()](https://opensource.org/licenses/MPL-2.0)
 [![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS)
 
-:loudspeaker: [English Voice Samples](https://erogol.github.io/ddc-samples/) and [SoundCloud playlist](https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2)
+📢 [English Voice Samples](https://erogol.github.io/ddc-samples/) and [SoundCloud playlist](https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2)
 
-:man_cook: [TTS training recipes](https://github.com/erogol/TTS_recipes)
+👩🏽‍🍳 [TTS training recipes](https://github.com/erogol/TTS_recipes)
 
-:page_facing_up: [Text-to-Speech paper collection](https://github.com/erogol/TTS-papers)
+📄 [Text-to-Speech paper collection](https://github.com/erogol/TTS-papers)
 
 ## 💬 Where to ask questions
 Please use our dedicated channels for questions and discussion. Help is much more valuable if it's shared publicly, so that more people can benefit from it.
@@ -93,21 +93,26 @@ Please use our dedicated channels for questions and discussion. Help is much mor
 You can also help us implement more models. Some TTS related work can be found [here](https://github.com/erogol/TTS-papers).
 
 ## Install TTS
-TTS supports **python >= 3.6, <3.9**.
+TTS is tested on Ubuntu 18.04 with **python >= 3.6, <3.9**.
 
 If you are only interested in [synthesizing speech](https://github.com/mozilla/TTS/tree/dev#example-synthesizing-speech-on-terminal-using-the-released-models) with the released TTS models, installing from PyPI is the easiest option.
 
-```
+```bash
 pip install TTS
 ```
 
 If you plan to code or train models, clone TTS and install it locally.
 
-```
+```bash
 git clone https://github.com/mozilla/TTS
 pip install -e .
 ```
 
+We use ```espeak``` to convert graphemes to phonemes. You might need to install it separately.
+```bash
+sudo apt-get install espeak
+```
+
 ## Directory Structure
 ```
 |- notebooks/ (Jupyter Notebooks for model evaluation, parameter selection and data analysis.)
@@ -157,16 +162,35 @@ Some of the public datasets that we successfully applied TTS:
 
 After the installation, TTS provides a CLI interface for synthesizing speech using pre-trained models. You can either use your own model or the released models under the TTS project.
 
 Listing released TTS models.
-```tts --list_models```
+```bash
+tts --list_models
+```
 
 Run a tts and a vocoder model from the released model list. (Simply copy and paste the full model names from the list as arguments for the command below.)
-```tts --text "Text for TTS" --model_name "///" --vocoder_name "///" --output_path```
+```bash
+tts --text "Text for TTS" \
+    --model_name "///" \
+    --vocoder_name "///" \
+    --out_path folder/to/save/output/
+```
 
 Run your own TTS model (Using Griffin-Lim Vocoder)
-```tts --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav```
+```bash
+tts --text "Text for TTS" \
+    --model_path path/to/model.pth.tar \
+    --config_path path/to/config.json \
+    --out_path output/path/speech.wav
+```
 
 Run your own TTS and Vocoder models
-```tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav --vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json```
+```bash
+tts --text "Text for TTS" \
+    --model_path path/to/model.pth.tar \
+    --config_path path/to/config.json \
+    --out_path output/path/speech.wav \
+    --vocoder_path path/to/vocoder.pth.tar \
+    --vocoder_config_path path/to/vocoder_config.json
+```
 
 **Note:** You can use ```./TTS/bin/synthesize.py``` if you prefer running ```tts``` from the TTS project folder.
@@ -185,19 +209,27 @@ To train a new model, you need to define your own ```config.json``` to define mo
 
 For instance, in order to train a tacotron or tacotron2 model on the LJSpeech dataset, follow these steps.
 
-```python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json```
+```bash
+python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json
+```
 
 To fine-tune a model, use ```--restore_path```.
 
-```python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json --restore_path /path/to/your/model.pth.tar```
+```bash
+python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json --restore_path /path/to/your/model.pth.tar
+```
 
 To continue an old training run, use ```--continue_path```.
 
-```python TTS/bin/train_tacotron.py --continue_path /path/to/your/run_folder/```
+```bash
+python TTS/bin/train_tacotron.py --continue_path /path/to/your/run_folder/
+```
 
 For multi-GPU training, call ```distribute.py```. It runs any provided train script in a multi-GPU setting.
 
-```CUDA_VISIBLE_DEVICES="0,1,4" python TTS/bin/distribute.py --script train_tacotron.py --config_path TTS/tts/configs/config.json```
+```bash
+CUDA_VISIBLE_DEVICES="0,1,4" python TTS/bin/distribute.py --script train_tacotron.py --config_path TTS/tts/configs/config.json
+```
 
 Each run creates a new output folder containing the used ```config.json```, model checkpoints and tensorboard logs.
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index b7ccf850..9a06c866 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -35,6 +35,9 @@ def main():
     # list provided models
     ./TTS/bin/synthesize.py --list_models
 
+    # run tts with default models.
+    ./TTS/bin/synthesize.py --text "Text for TTS"
+
     # run a model from the list
     ./TTS/bin/synthesize.py --text "Text for TTS" --model_name "//" --vocoder_name "//" --output_path
@@ -67,14 +70,14 @@ def main():
     parser.add_argument(
         '--model_name',
         type=str,
-        default=None,
+        default="tts_models/en/ljspeech/speedy-speech-wn",
         help=
         'Name of one of the pre-trained tts models in format //'
     )
     parser.add_argument(
         '--vocoder_name',
         type=str,
-        default=None,
+        default="vocoder_models/en/ljspeech/mulitband-melgan",
         help=
         'Name of one of the pre-trained vocoder models in format //'
     )
diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
index ccb35a7c..be609905 100644
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -534,7 +534,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         optimizer_st = None
 
     # setup criterion
-    criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
+    criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4)
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py
index 73802c63..fe5fb3d7 100644
--- a/TTS/bin/train_vocoder_wavegrad.py
+++ b/TTS/bin/train_vocoder_wavegrad.py
@@ -345,6 +345,10 @@ def main(args):  # pylint: disable=redefined-outer-name
     # setup criterion
     criterion = torch.nn.L1Loss().cuda()
 
+    if use_cuda:
+        model.cuda()
+        criterion.cuda()
+
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
         try:
@@ -378,10 +382,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         args.restore_step = 0
 
-    if use_cuda:
-        model.cuda()
-        criterion.cuda()
-
     # DISTRIBUTED
     if num_gpus > 1:
         model = DDP_th(model, device_ids=[args.rank])
diff --git a/TTS/bin/train_vocoder_wavernn.py b/TTS/bin/train_vocoder_wavernn.py
index cad357dc..14d57837 100644
--- a/TTS/bin/train_vocoder_wavernn.py
+++ b/TTS/bin/train_vocoder_wavernn.py
@@ -200,7 +200,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
                 train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
             wav = ap.load_wav(wav_path)
             ground_mel = ap.melspectrogram(wav)
-            sample_wav = model.generate(ground_mel,
+            sample_wav = model.inference(ground_mel,
                                         c.batched,
                                         c.target_samples,
                                         c.overlap_samples,
@@ -287,7 +287,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                 eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
             wav = ap.load_wav(wav_path)
             ground_mel = ap.melspectrogram(wav)
-            sample_wav = model.generate(ground_mel,
+            sample_wav = model.inference(ground_mel,
                                         c.batched,
                                         c.target_samples,
                                         c.overlap_samples,
diff --git a/TTS/server/server.py b/TTS/server/server.py
index 1f7357af..425879cf 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -17,8 +17,8 @@ def create_argparser():
     parser = argparse.ArgumentParser()
     parser.add_argument('--list_models', type=convert_boolean, nargs='?', const=True, default=False, help='list available pre-trained tts and vocoder models.')
-    parser.add_argument('--model_name', type=str, help='name of one of the released tts models.')
-    parser.add_argument('--vocoder_name', type=str, help='name of one of the released vocoder models.')
+    parser.add_argument('--model_name', type=str, default="tts_models/en/ljspeech/speedy-speech-wn", help='name of one of the released tts models.')
+    parser.add_argument('--vocoder_name', type=str, default="vocoder_models/en/ljspeech/mulitband-melgan", help='name of one of the released vocoder models.')
     parser.add_argument('--tts_checkpoint', type=str, help='path to custom tts checkpoint file')
     parser.add_argument('--tts_config', type=str, help='path to custom tts config.json file')
     parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 7e71df64..be587211 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -1,3 +1,5 @@
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 import pkg_resources
 installed = {pkg.key for pkg in pkg_resources.working_set}  #pylint: disable=not-an-iterable
 if 'tensorflow' in installed or 'tensorflow-gpu' in installed:
diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index 93a5880f..3d31ce6e 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -3,7 +3,7 @@ import soundfile as sf
 import numpy as np
 import scipy.io.wavfile
 import scipy.signal
-import pyworld as pw
+# import pyworld as pw
 
 from TTS.tts.utils.data import StandardScaler
 
@@ -292,15 +292,16 @@ class AudioProcessor(object):
         return pad // 2, pad // 2 + pad % 2
 
     ### Compute F0 ###
-    def compute_f0(self, x):
-        f0, t = pw.dio(
-            x.astype(np.double),
-            fs=self.sample_rate,
-            f0_ceil=self.mel_fmax,
-            frame_period=1000 * self.hop_length / self.sample_rate,
-        )
-        f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
-        return f0
+    # TODO: pw causes some dep issues
+    # def compute_f0(self, x):
+    #     f0, t = pw.dio(
+    #         x.astype(np.double),
+    #         fs=self.sample_rate,
+    #         f0_ceil=self.mel_fmax,
+    #         frame_period=1000 * self.hop_length / self.sample_rate,
+    #     )
+    #     f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
+    #     return f0
 
     ### Audio Processing ###
     def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index af741156..3cf8d67f 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -1,10 +1,11 @@
 import json
-import gdown
-from pathlib import Path
 import os
+from pathlib import Path
 
-from TTS.utils.io import load_config
+import gdown
 from TTS.utils.generic_utils import get_user_data_dir
+from TTS.utils.io import load_config
+
 
 class ModelManager(object):
     """Manage TTS models defined in .models.json.
@@ -17,12 +18,17 @@ class ModelManager(object):
     Args:
         models_file (str): path to .models.json
     """
-    def __init__(self, models_file):
+    def __init__(self, models_file=None):
         super().__init__()
         self.output_prefix = get_user_data_dir('tts')
         self.url_prefix = "https://drive.google.com/uc?id="
         self.models_dict = None
-        self.read_models_file(models_file)
+        if models_file is not None:
+            self.read_models_file(models_file)
+        else:
+            # try the default location
+            path = Path(__file__).parent / "../.models.json"
+            self.read_models_file(path)
 
     def read_models_file(self, file_path):
         """Read .models.json as a dict
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 615e0d1d..85e116cf 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -11,7 +11,7 @@ from TTS.tts.utils.speakers import load_speaker_mapping
 from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
-from TTS.tts.utils.synthesis import *
+from TTS.tts.utils.synthesis import synthesis, trim_silence
 from TTS.tts.utils.text import make_symbols, phonemes, symbols
 
@@ -79,7 +79,7 @@ class Synthesizer(object):
         self.tts_config = load_config(tts_config)
         self.use_phonemes = self.tts_config.use_phonemes
-        self.ap = AudioProcessor(**self.tts_config.audio)
+        self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
 
         if 'characters' in self.tts_config.keys():
             symbols, phonemes = make_symbols(**self.tts_config.characters)
@@ -96,7 +96,7 @@ class Synthesizer(object):
 
     def load_vocoder(self, model_file, model_config, use_cuda):
         self.vocoder_config = load_config(model_config)
-        self.vocoder_ap = AudioProcessor(**self.vocoder_config['audio'])
+        self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config['audio'])
         self.vocoder_model = setup_generator(self.vocoder_config)
         self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
         if use_cuda:
diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 00000000..9de4f7b2
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,35 @@
+dependencies = ['torch', 'gdown']
+import torch
+
+from TTS.utils.synthesizer import Synthesizer
+from TTS.utils.manage import ModelManager
+
+
+def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder_models/en/ljspeech/mulitband-melgan'):
+    """TTS entry point for PyTorch Hub that provides a Synthesizer object to synthesize speech from a given text.
+
+    Example:
+        >>> synthesizer = torch.hub.load('mozilla/TTS', 'tts', source='github')
+        >>> wavs = synthesizer.tts("This is a test! This is also a test!!")
+        wavs is a list of values of the synthesized speech.
+
+    Args:
+        model_name (str, optional): One of the model names from .models.json. Defaults to 'tts_models/en/ljspeech/tacotron2-DCA'.
+        vocoder_name (str, optional): One of the model names from .models.json. Defaults to 'vocoder_models/en/ljspeech/mulitband-melgan'.
+
+    Returns:
+        TTS.utils.synthesizer.Synthesizer: Synthesizer object wrapping both vocoder and tts models.
+ """ + manager = ModelManager() + + model_path, config_path = manager.download_model(model_name) + vocoder_path, vocoder_config_path = manager.download_model(vocoder_name) + + # create synthesizer + synt = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path) + return synt + + +if __name__ == '__main__': + synthesizer = torch.hub.load('mozilla/TTS:hub_conf', 'tts', source='github') + synthesizer.tts("This is a test!") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fc0aca47..8b8da28d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,2 @@ [build-system] -requires = ["setuptools", "wheel", "Cython", "numpy>=1.16.0"] \ No newline at end of file +requires = ["setuptools", "wheel", "Cython", "numpy==1.17.5"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 31b49916..7a0d9f76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,11 @@ torch>=1.5 tensorflow==2.3.1 -numpy>=1.16.0 +numpy==1.17.5 scipy>=0.19.0 numba==0.48 librosa==0.7.2 phonemizer>=2.2.0 unidecode==0.4.20 -attrdict tensorboardX matplotlib Pillow @@ -23,4 +22,4 @@ pylint==2.5.3 gdown umap-learn cython -pyyaml +pyyaml \ No newline at end of file diff --git a/setup.py b/setup.py index 6cc06f89..53a142a1 100644 --- a/setup.py +++ b/setup.py @@ -5,15 +5,21 @@ import os import shutil import subprocess import sys +from distutils.version import LooseVersion import numpy import setuptools.command.build_py import setuptools.command.develop - -from setuptools import find_packages, setup -from distutils.extension import Extension +from setuptools import setup, Extension, find_packages from Cython.Build import cythonize + +if LooseVersion(sys.version) < LooseVersion("3.6") or LooseVersion(sys.version) > LooseVersion("3.9"): + raise RuntimeError( + "TTS requires python >= 3.6 and <3.9 " + "but your Python version is {}".format(sys.version) + ) + # parameters for wheeling server. parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) parser.add_argument('--checkpoint', diff --git a/tests/test_vocoder_gan_datasets.py b/tests/test_vocoder_gan_datasets.py index 2a487d9a..99a25dcf 100644 --- a/tests/test_vocoder_gan_datasets.py +++ b/tests/test_vocoder_gan_datasets.py @@ -61,7 +61,8 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us mel = ap.melspectrogram(audio) # the first 2 and the last 2 frames are skipped due to the padding # differences in stft - assert (feat - mel[:, :feat1.shape[-1]])[:, 2:-2].sum() <= 0, f' [!] {(feat - mel[:, :feat1.shape[-1]])[:, 2:-2].sum()}' + max_diff = abs((feat - mel[:, :feat1.shape[-1]])[:, 2:-2]).max() + assert max_diff <= 0, f' [!] {max_diff}' count_iter += 1 # if count_iter == max_iter: