diff --git a/.compute b/.compute deleted file mode 100644 index 9786a689..00000000 --- a/.compute +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -yes | apt-get install sox -yes | apt-get install ffmpeg -yes | apt-get install tmux -yes | apt-get install zsh -sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" -pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl -sudo sh install.sh -# pip install pytorch==1.7.0+cu100 -# python3 setup.py develop -# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/ -# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/ -# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/ -# python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360 -# python3 distribute.py --config_path config.json -while true; do sleep 1000000; done diff --git a/.github/workflows/data_tests.yml b/.github/workflows/data_tests.yml new file mode 100644 index 00000000..296aa570 --- /dev/null +++ b/.github/workflows/data_tests.yml @@ -0,0 +1,46 @@ +name: data-tests + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + check_skip: + runs-on: ubuntu-latest + if: "! contains(github.event.head_commit.message, '[ci skip]')" + steps: + - run: echo "${{ github.event.head_commit.message }}" + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + experimental: [false] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: coqui-ai/setup-python@pip-cache-key-py-ver + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel + - name: Install TTS + run: | + python3 -m pip install .[all] + python3 setup.py egg_info + - name: Unit tests + run: make data_tests diff --git a/.github/workflows/inference_tests.yml b/.github/workflows/inference_tests.yml new file mode 100644 index 00000000..3f08b904 --- /dev/null +++ b/.github/workflows/inference_tests.yml @@ -0,0 +1,46 @@ +name: inference_tests + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + check_skip: + runs-on: ubuntu-latest + if: "! 
contains(github.event.head_commit.message, '[ci skip]')" + steps: + - run: echo "${{ github.event.head_commit.message }}" + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + experimental: [false] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: coqui-ai/setup-python@pip-cache-key-py-ver + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel + - name: Install TTS + run: | + python3 -m pip install .[all] + python3 setup.py egg_info + - name: Unit tests + run: make inference_tests diff --git a/.github/workflows/text_tests.yml b/.github/workflows/text_tests.yml new file mode 100644 index 00000000..e06a25ad --- /dev/null +++ b/.github/workflows/text_tests.yml @@ -0,0 +1,48 @@ +name: tts-tests + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + check_skip: + runs-on: ubuntu-latest + if: "! contains(github.event.head_commit.message, '[ci skip]')" + steps: + - run: echo "${{ github.event.head_commit.message }}" + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + experimental: [false] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: coqui-ai/setup-python@pip-cache-key-py-ver + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + sudo apt-get install espeak + sudo apt-get install espeak-ng + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel + - name: Install TTS + run: | + python3 -m pip install .[all] + python3 setup.py egg_info + - name: Unit tests + run: make test_text diff --git a/.github/workflows/tts_tests.yml b/.github/workflows/tts_tests.yml index e352a117..0a5891ee 100644 --- a/.github/workflows/tts_tests.yml +++ b/.github/workflows/tts_tests.yml @@ -35,6 +35,8 @@ jobs: run: | sudo apt-get update sudo apt-get install -y --no-install-recommends git make gcc + sudo apt-get install espeak + sudo apt-get install espeak-ng make system-deps - name: Install/upgrade Python setup deps run: python3 -m pip install --upgrade pip setuptools wheel diff --git a/.github/workflows/zoo_tests.yml b/.github/workflows/zoo_tests.yml index f973dd0e..94d54200 100644 --- a/.github/workflows/zoo_tests.yml +++ b/.github/workflows/zoo_tests.yml @@ -35,6 +35,7 @@ jobs: run: | sudo apt-get update sudo apt-get install -y git make gcc + sudo apt-get install espeak espeak-ng make system-deps - name: Install/upgrade Python setup deps run: python3 -m pip install --upgrade pip setuptools wheel diff --git a/.gitignore b/.gitignore index 7e9da0d8..f8d6e644 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,5 @@ internal/* *_pitch.npy *_phoneme.npy wandb -depot/* \ No newline at end of file +depot/* +coqui_recipes/* \ No newline at end of file diff --git a/.pylintrc 
b/.pylintrc index 6e9f953e..d5f9c490 100644 --- a/.pylintrc +++ b/.pylintrc @@ -168,7 +168,8 @@ disable=missing-docstring, exception-escape, comprehension-escape, duplicate-code, - not-callable + not-callable, + import-outside-toplevel # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/Makefile b/Makefile index 32b4638b..d04cd976 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,15 @@ test_aux: ## run aux tests. test_zoo: ## run zoo tests. nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id +inference_tests: ## run inference tests. + nosetests tests.inference_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.inference_tests --nologcapture --with-id + +data_tests: ## run data tests. + nosetests tests.data_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.data_tests --nologcapture --with-id + +test_text: ## run text tests. + nosetests tests.text_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.text_tests --nologcapture --with-id + test_failed: ## only run tests failed the last time. nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed @@ -41,7 +50,6 @@ system-deps: ## install linux system deps dev-deps: ## install development deps pip install -r requirements.dev.txt - pip install -r requirements.tf.txt doc-deps: ## install docs dependencies pip install -r docs/requirements.txt diff --git a/README.md b/README.md index 4686ac67..80fa5dea 100644 --- a/README.md +++ b/README.md @@ -61,8 +61,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models - Detailed training logs on the terminal and Tensorboard. - Support for Multi-speaker TTS. - Efficient, flexible, lightweight but feature complete `Trainer API`. -- Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference. -- Released and read-to-use models. +- Released and ready-to-use models. - Tools to curate Text2Speech datasets under```dataset_analysis```. - Utilities to use and test your models. - Modular (but not too much) code base enabling easy implementation of new ideas. @@ -113,17 +112,11 @@ If you are only interested in [synthesizing speech](https://tts.readthedocs.io/e pip install TTS ``` -By default, this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra. - -```bash -pip install TTS[tf] -``` - If you plan to code or train models, clone 🐸TTS and install it locally. ```bash git clone https://github.com/coqui-ai/TTS -pip install -e .[all,dev,notebooks,tf] # Select the relevant extras +pip install -e .[all,dev,notebooks] # Select the relevant extras ``` If you are on Ubuntu (Debian), you can also run following commands for installation. @@ -204,12 +197,10 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht |- train*.py (train your target model.) |- distribute.py (train your TTS model using Multiple GPUs.) |- compute_statistics.py (compute dataset statistics for normalization.) - |- convert*.py (convert target torch model to TF.) |- ... |- tts/ (text to speech models) |- layers/ (model layer definitions) |- models/ (model definitions) - |- tf/ (Tensorflow 2 utilities and model implementations) |- utils/ (model specific utilities.) |- speaker_encoder/ (Speaker Encoder models.) 
|- (same) diff --git a/TTS/.models.json b/TTS/.models.json index 61a3257d..801b8468 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -4,7 +4,7 @@ "multi-dataset":{ "your_tts":{ "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--multilingual--multi-dataset--your_tts.zip", "default_vocoder": null, "commit": "e9a1953e", "license": "CC BY-NC-ND 4.0", @@ -33,7 +33,7 @@ }, "tacotron2-DDC_ph": { "description": "Tacotron2 with Double Decoder Consistency with phonemes.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--tacotron2-DDC_ph.zip", "default_vocoder": "vocoder_models/en/ljspeech/univnet", "commit": "3900448", "author": "Eren Gölge @erogol", @@ -71,7 +71,7 @@ }, "vits": { "description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--vits.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--vits.zip", "default_vocoder": null, "commit": "3900448", "author": "Eren Gölge @erogol", @@ -89,18 +89,9 @@ } }, "vctk": { - "sc-glow-tts": { - "description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip", - "default_vocoder": "vocoder_models/en/vctk/hifigan_v2", - "commit": "b531fa69", - "author": "Edresson Casanova", - "license": "", - "contact": "" - }, "vits": { "description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--vctk--vits.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--vits.zip", "default_vocoder": null, "commit": "3900448", "author": "Eren @erogol", @@ -109,7 +100,7 @@ }, "fast_pitch":{ "description": "FastPitch model trained on VCTK dataseset.", - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--en--vctk--fast_pitch.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--fast_pitch.zip", "default_vocoder": null, "commit": "bdab788d", "author": "Eren @erogol", @@ -156,7 +147,7 @@ "uk":{ "mai": { "glow-tts": { - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--uk--mailabs--glow-tts.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--uk--mai--glow-tts.zip", "author":"@robinhad", "commit": "bdab788d", "license": "MIT", @@ -168,7 +159,7 @@ "zh-CN": { "baker": { "tacotron2-DDC-GST": { - "github_rls_url": "https://coqui.gateway.scarf.sh/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip", "commit": "unknown", "author": "@kirianguiller", "default_vocoder": null @@ -206,6 +197,52 @@ "commit": "401fbd89" } } + }, + "tr":{ + "common-voice": { + "glow-tts":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--tr--common-voice--glow-tts.zip", + "default_vocoder": "vocoder_models/tr/common-voice/hifigan", 
+ "license": "MIT", + "description": "Turkish GlowTTS model using an unknown speaker from the Common-Voice dataset.", + "author": "Fatih Akademi", + "commit": null + } + } + }, + "it": { + "mai_female": { + "glow-tts":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--glow-tts.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "commit": null + }, + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--vits.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "commit": null + } + }, + "mai_male": { + "glow-tts":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--glow-tts.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "commit": null + }, + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--vits.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "commit": null + } + } } }, "vocoder_models": { @@ -324,6 +361,17 @@ "contact": "" } } + }, + "tr":{ + "common-voice": { + "hifigan":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/vocoder_models--tr--common-voice--hifigan.zip", + "description": "HifiGAN model using an unknown speaker from the Common-Voice dataset.", + "author": "Fatih Akademi", + "license": "MIT", + "commit": null + } + } } } } \ No newline at end of file diff --git a/TTS/VERSION b/TTS/VERSION index 79a2734b..09a3acfa 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.5.0 \ No newline at end of file +0.6.0 \ No newline at end of file diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index fc8c6629..e58259a6 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -11,7 +11,7 @@ from tqdm import tqdm from TTS.config import load_config from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.models import setup_model -from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols +from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_checkpoint diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 83a5aeae..50817154 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -29,6 +29,9 @@ parser.add_argument( help="Path to dataset config file.", ) parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.") +parser.add_argument( + "--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None +) parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) parser.add_argument("--eval", type=bool, help="compute eval.", default=True) @@ -40,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli wav_files = meta_data_train + meta_data_eval speaker_manager = SpeakerManager( - encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda + 
encoder_model_path=args.model_path, + encoder_config_path=args.config_path, + d_vectors_file_path=args.old_file, + use_cuda=args.use_cuda, ) # compute speaker embeddings @@ -52,11 +58,15 @@ for idx, wav_file in enumerate(tqdm(wav_files)): else: speaker_name = None - # extract the embedding - embedd = speaker_manager.compute_d_vector_from_clip(wav_file) + wav_file_name = os.path.basename(wav_file) + if args.old_file is not None and wav_file_name in speaker_manager.clip_ids: + # get the embedding from the old file + embedd = speaker_manager.get_d_vector_by_clip(wav_file_name) + else: + # extract the embedding + embedd = speaker_manager.compute_d_vector_from_clip(wav_file) # create speaker_mapping if target dataset is defined - wav_file_name = os.path.basename(wav_file) speaker_mapping[wav_file_name] = {} speaker_mapping[wav_file_name]["name"] = speaker_name speaker_mapping[wav_file_name]["embedding"] = embedd diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index e1974ae7..3ab7ea7a 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -51,7 +51,7 @@ def main(): N = 0 for item in tqdm(dataset_items): # compute features - wav = ap.load_wav(item if isinstance(item, str) else item[1]) + wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"]) linear = ap.spectrogram(wav) mel = ap.melspectrogram(wav) @@ -59,13 +59,13 @@ def main(): N += mel.shape[1] mel_sum += mel.sum(1) linear_sum += linear.sum(1) - mel_square_sum += (mel ** 2).sum(axis=1) - linear_square_sum += (linear ** 2).sum(axis=1) + mel_square_sum += (mel**2).sum(axis=1) + linear_square_sum += (linear**2).sum(axis=1) mel_mean = mel_sum / N - mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2) + mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2) linear_mean = linear_sum / N - linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2) + linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2) output_file_path = args.out_path stats = {} diff --git a/TTS/bin/convert_melgan_tflite.py b/TTS/bin/convert_melgan_tflite.py deleted file mode 100644 index a3a3fb66..00000000 --- a/TTS/bin/convert_melgan_tflite.py +++ /dev/null @@ -1,25 +0,0 @@ -# Convert Tensorflow Tacotron2 model to TF-Lite binary - -import argparse - -from TTS.utils.io import load_config -from TTS.vocoder.tf.utils.generic_utils import setup_generator -from TTS.vocoder.tf.utils.io import load_checkpoint -from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite - -parser = argparse.ArgumentParser() -parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.") -parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") -parser.add_argument("--output_path", type=str, help="path to tflite output binary.") -args = parser.parse_args() - -# Set constants -CONFIG = load_config(args.config_path) - -# load the model -model = setup_generator(CONFIG) -model.build_inference() -model = load_checkpoint(model, args.tf_model) - -# create tflite model -tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path) diff --git a/TTS/bin/convert_melgan_torch_to_tf.py b/TTS/bin/convert_melgan_torch_to_tf.py deleted file mode 100644 index c1fb8498..00000000 --- a/TTS/bin/convert_melgan_torch_to_tf.py +++ /dev/null @@ -1,105 +0,0 @@ -import argparse -import os -from difflib import SequenceMatcher - -import numpy as np -import tensorflow as tf -import torch - -from TTS.utils.io import load_config, load_fsspec -from 
TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( - compare_torch_tf, - convert_tf_name, - transfer_weights_torch_to_tf, -) -from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator -from TTS.vocoder.tf.utils.io import save_checkpoint -from TTS.vocoder.utils.generic_utils import setup_generator - -# prevent GPU use -os.environ["CUDA_VISIBLE_DEVICES"] = "" - -# define args -parser = argparse.ArgumentParser() -parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.") -parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") -parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.") -args = parser.parse_args() - -# load model config -config_path = args.config_path -c = load_config(config_path) -num_speakers = 0 - -# init torch model -model = setup_generator(c) -checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu")) -state_dict = checkpoint["model"] -model.load_state_dict(state_dict) -model.remove_weight_norm() -state_dict = model.state_dict() - -# init tf model -model_tf = setup_tf_generator(c) - -common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE" -# get tf_model graph by passing an input -# B x D x T -dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32) -mel_pred = model_tf(dummy_input, training=False) - -# get tf variables -tf_vars = model_tf.weights - -# match variable names with fuzzy logic -torch_var_names = list(state_dict.keys()) -tf_var_names = [we.name for we in model_tf.weights] -var_map = [] -for tf_name in tf_var_names: - # skip re-mapped layer names - if tf_name in [name[0] for name in var_map]: - continue - tf_name_edited = convert_tf_name(tf_name) - ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names] - max_idx = np.argmax(ratios) - matching_name = torch_var_names[max_idx] - del torch_var_names[max_idx] - var_map.append((tf_name, matching_name)) - -# pass weights -tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict) - -# Compare TF and TORCH models -# check embedding outputs -model.eval() -dummy_input_torch = torch.ones((1, 80, 10)) -dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy()) -dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1]) -dummy_input_tf = tf.expand_dims(dummy_input_tf, 2) - -out_torch = model.layers[0](dummy_input_torch) -out_tf = model_tf.model_layers[0](dummy_input_tf) -out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :] - -assert compare_torch_tf(out_torch, out_tf_) < 1e-5 - -for i in range(1, len(model.layers)): - print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}") - out_torch = model.layers[i](out_torch) - out_tf = model_tf.model_layers[i](out_tf) - out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :] - diff = compare_torch_tf(out_torch, out_tf_) - assert diff < 1e-5, diff - -torch.manual_seed(0) -dummy_input_torch = torch.rand((1, 80, 100)) -dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy()) -model.inference_padding = 0 -model_tf.inference_padding = 0 -output_torch = model.inference(dummy_input_torch) -output_tf = model_tf(dummy_input_tf, training=False) -assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf) - -# save tf model -save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path) -print(" > Model conversion is successfully completed :).") 
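The `--old_file` option added to `TTS/bin/compute_embeddings.py` earlier in this diff lets a previously generated `speakers.json` be reused so that only new clips are pushed through the speaker encoder. Below is a minimal sketch of that reuse pattern, using only the `SpeakerManager` calls visible in the diff (`clip_ids`, `get_d_vector_by_clip`, `compute_d_vector_from_clip`); all file paths and the clip list are placeholders, not values from the source.

```python
import os

from TTS.tts.utils.speakers import SpeakerManager

# Placeholder paths; substitute your encoder checkpoint, its config and the old mapping.
manager = SpeakerManager(
    encoder_model_path="speaker_encoder_model.pth.tar",
    encoder_config_path="speaker_encoder_config.json",
    d_vectors_file_path="old_speakers.json",  # the file passed via --old_file
    use_cuda=False,
)

speaker_mapping = {}
for wav_file in ["clips/clip_0001.wav", "clips/clip_0002.wav"]:  # placeholder clip list
    key = os.path.basename(wav_file)
    if key in manager.clip_ids:
        # clip was embedded in a previous run -> reuse the stored d-vector
        embedding = manager.get_d_vector_by_clip(key)
    else:
        # new clip -> run the speaker encoder
        embedding = manager.compute_d_vector_from_clip(wav_file)
    speaker_mapping[key] = {"name": None, "embedding": embedding}
```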
diff --git a/TTS/bin/convert_tacotron2_tflite.py b/TTS/bin/convert_tacotron2_tflite.py deleted file mode 100644 index 327d0ae8..00000000 --- a/TTS/bin/convert_tacotron2_tflite.py +++ /dev/null @@ -1,30 +0,0 @@ -# Convert Tensorflow Tacotron2 model to TF-Lite binary - -import argparse - -from TTS.tts.tf.utils.generic_utils import setup_model -from TTS.tts.tf.utils.io import load_checkpoint -from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite -from TTS.tts.utils.text.symbols import phonemes, symbols -from TTS.utils.io import load_config - -parser = argparse.ArgumentParser() -parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.") -parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") -parser.add_argument("--output_path", type=str, help="path to tflite output binary.") -args = parser.parse_args() - -# Set constants -CONFIG = load_config(args.config_path) - -# load the model -c = CONFIG -num_speakers = 0 -num_chars = len(phonemes) if c.use_phonemes else len(symbols) -model = setup_model(num_chars, num_speakers, c, enable_tflite=True) -model.build_inference() -model = load_checkpoint(model, args.tf_model) -model.decoder.set_max_decoder_steps(1000) - -# create tflite model -tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path) diff --git a/TTS/bin/convert_tacotron2_torch_to_tf.py b/TTS/bin/convert_tacotron2_torch_to_tf.py deleted file mode 100644 index 78c6b362..00000000 --- a/TTS/bin/convert_tacotron2_torch_to_tf.py +++ /dev/null @@ -1,187 +0,0 @@ -import argparse -import os -import sys -from difflib import SequenceMatcher -from pprint import pprint - -import numpy as np -import tensorflow as tf -import torch - -from TTS.tts.models import setup_model -from TTS.tts.tf.models.tacotron2 import Tacotron2 -from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf -from TTS.tts.tf.utils.generic_utils import save_checkpoint -from TTS.tts.utils.text.symbols import phonemes, symbols -from TTS.utils.io import load_config, load_fsspec - -sys.path.append("/home/erogol/Projects") -os.environ["CUDA_VISIBLE_DEVICES"] = "" - - -parser = argparse.ArgumentParser() -parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.") -parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") -parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.") -args = parser.parse_args() - -# load model config -config_path = args.config_path -c = load_config(config_path) -num_speakers = 0 - -# init torch model -model = setup_model(c) -checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu")) -state_dict = checkpoint["model"] -model.load_state_dict(state_dict) - -# init tf model -num_chars = len(phonemes) if c.use_phonemes else len(symbols) -model_tf = Tacotron2( - num_chars=num_chars, - num_speakers=num_speakers, - r=model.decoder.r, - out_channels=c.audio["num_mels"], - decoder_output_dim=c.audio["num_mels"], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - 
bidirectional_decoder=c.bidirectional_decoder, -) - -# set initial layer mapping - these are not captured by the below heuristic approach -# TODO: set layer names so that we can remove these manual matching -common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE" -var_map = [ - ("embedding/embeddings:0", "embedding.weight"), - ("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"), - ("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"), - ("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"), - ("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"), - ("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")), - ( - "encoder/lstm/backward_lstm/lstm_cell_2/bias:0", - ("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"), - ), - ("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"), - ("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"), - ("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"), -] - -# %% -# get tf_model graph -model_tf.build_inference() - -# get tf variables -tf_vars = model_tf.weights - -# match variable names with fuzzy logic -torch_var_names = list(state_dict.keys()) -tf_var_names = [we.name for we in model_tf.weights] -for tf_name in tf_var_names: - # skip re-mapped layer names - if tf_name in [name[0] for name in var_map]: - continue - tf_name_edited = convert_tf_name(tf_name) - ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names] - max_idx = np.argmax(ratios) - matching_name = torch_var_names[max_idx] - del torch_var_names[max_idx] - var_map.append((tf_name, matching_name)) - -pprint(var_map) -pprint(torch_var_names) - -# pass weights -tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict) - -# Compare TF and TORCH models -# %% -# check embedding outputs -model.eval() -input_ids = torch.randint(0, 24, (1, 128)).long() - -o_t = model.embedding(input_ids) -o_tf = model_tf.embedding(input_ids.detach().numpy()) -assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum() - -# compare encoder outputs -oo_en = model.encoder.inference(o_t.transpose(1, 2)) -ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False) -assert compare_torch_tf(oo_en, ooo_en) < 1e-5 - -# pylint: disable=redefined-builtin -# compare decoder.attention_rnn -inp = torch.rand([1, 768]) -inp_tf = inp.numpy() -model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access -output, cell_state = model.decoder.attention_rnn(inp) -states = model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False) -assert compare_torch_tf(output, output_tf).mean() < 1e-5 - -query = output -inputs = torch.rand([1, 128, 512]) -query_tf = query.detach().numpy() -inputs_tf = inputs.numpy() - -# compare decoder.attention -model.decoder.attention.init_states(inputs) -processes_inputs = model.decoder.attention.preprocess_inputs(inputs) -loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs) -context = model.decoder.attention(query, inputs, processes_inputs, None) - -attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1] 
-model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf)) -loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states) -context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False) - -assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5 -assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5 -assert compare_torch_tf(context, context_tf) < 1e-5 - -# compare decoder.decoder_rnn -input = torch.rand([1, 1536]) -input_tf = input.numpy() -model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access -output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell]) -states = model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False) -assert abs(input - input_tf).mean() < 1e-5 -assert compare_torch_tf(output, output_tf).mean() < 1e-5 - -# compare decoder.linear_projection -input = torch.rand([1, 1536]) -input_tf = input.numpy() -output = model.decoder.linear_projection(input) -output_tf = model_tf.decoder.linear_projection(input_tf, training=False) -assert compare_torch_tf(output, output_tf) < 1e-5 - -# compare decoder outputs -model.decoder.max_decoder_steps = 100 -model_tf.decoder.set_max_decoder_steps(100) -output, align, stop = model.decoder.inference(oo_en) -states = model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False) -assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4 - -# compare the whole model output -outputs_torch = model.inference(input_ids) -outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy())) -print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean()) -assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5 -assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4 - -# %% -# save tf model -save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path) -print(" > Model conversion is successfully completed :).") diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py index 06d5f388..97e2f0e3 100644 --- a/TTS/bin/distribute.py +++ b/TTS/bin/distribute.py @@ -7,15 +7,14 @@ import subprocess import time import torch - -from TTS.trainer import TrainingArgs +from trainer import TrainerArgs def main(): """ Call train.py as a new process and pass command arguments """ - parser = TrainingArgs().init_argparse(arg_prefix="") + parser = TrainerArgs().init_argparse(arg_prefix="") parser.add_argument("--script", type=str, help="Target training script to distibute.") args, unargs = parser.parse_known_args() diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 7b489fd6..fa63c46a 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -13,6 +13,7 @@ from TTS.config import load_config from TTS.tts.datasets import TTSDataset, load_tts_samples from TTS.tts.models import setup_model from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters @@ -20,21 +21,20 @@ use_cuda = torch.cuda.is_available() def setup_loader(ap, r, verbose=False): + tokenizer, _ = TTSTokenizer.init_from_config(c) dataset = 
TTSDataset( - r, - c.text_cleaner, + outputs_per_step=r, compute_linear_spec=False, - meta_data=meta_data, + samples=meta_data, + tokenizer=tokenizer, ap=ap, - characters=c.characters if "characters" in c.keys() else None, - add_blank=c["add_blank"] if "add_blank" in c.keys() else False, batch_group_size=0, - min_seq_len=c.min_seq_len, - max_seq_len=c.max_seq_len, + min_text_len=c.min_text_len, + max_text_len=c.max_text_len, + min_audio_len=c.min_audio_len, + max_audio_len=c.max_audio_len, phoneme_cache_path=c.phoneme_cache_path, - use_phonemes=c.use_phonemes, - phoneme_language=c.phoneme_language, - enable_eos_bos=c.enable_eos_bos_chars, + precompute_num_workers=0, use_noise_augment=False, verbose=verbose, speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None, @@ -44,7 +44,7 @@ def setup_loader(ap, r, verbose=False): if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. dataset.compute_input_seq(c.num_loader_workers) - dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False)) + dataset.preprocess_samples() loader = DataLoader( dataset, @@ -75,8 +75,8 @@ def set_filename(wav_path, out_path): def format_data(data): # setup input data - text_input = data["text"] - text_lengths = data["text_lengths"] + text_input = data["token_id"] + text_lengths = data["token_id_lengths"] mel_input = data["mel"] mel_lengths = data["mel_lengths"] item_idx = data["item_idxs"] @@ -138,7 +138,7 @@ def inference( aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}, ) model_output = outputs["model_outputs"] - model_output = model_output.transpose(1, 2).detach().cpu().numpy() + model_output = model_output.detach().cpu().numpy() elif "tacotron" in model_name: aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} @@ -229,7 +229,9 @@ def main(args): # pylint: disable=redefined-outer-name ap = AudioProcessor(**c.audio) # load data instances - meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval) + meta_data_train, meta_data_eval = load_tts_samples( + c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) # use eval and training partitions meta_data = meta_data_train + meta_data_eval diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index 437c2d60..4689dcad 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -23,7 +23,10 @@ def main(): c = load_config(args.config_path) # load all datasets - train_items, eval_items = load_tts_samples(c.datasets, eval_split=True) + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + items = train_items + eval_items texts = "".join(item[0] for item in items) diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index d3143ca3..0ae74bd4 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -7,14 +7,15 @@ from tqdm.contrib.concurrent import process_map from TTS.config import load_config from TTS.tts.datasets import load_tts_samples -from TTS.tts.utils.text import text2phone +from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + +phonemizer = Gruut(language="en-us") def compute_phonemes(item): try: text = item[0] - language = item[-1] - ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|") + ph = 
phonemizer.phonemize(text).split("|") except: return [] return list(set(ph)) @@ -39,10 +40,17 @@ def main(): c = load_config(args.config_path) # load all datasets - train_items, eval_items = load_tts_samples(c.datasets, eval_split=True) + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) items = train_items + eval_items print("Num items:", len(items)) + is_lang_def = all(item["language"] for item in items) + + if not c.phoneme_language or not is_lang_def: + raise ValueError("Phoneme language must be defined in config.") + phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15) phones = [] for ph in phonemes: diff --git a/TTS/bin/resample.py b/TTS/bin/resample.py index 3c5ef29c..c9f1166a 100644 --- a/TTS/bin/resample.py +++ b/TTS/bin/resample.py @@ -26,6 +26,7 @@ if __name__ == "__main__": --input_dir /root/LJSpeech-1.1/ --output_sr 22050 --output_dir /root/resampled_LJSpeech-1.1/ + --file_ext wav --n_jobs 24 """, formatter_class=RawTextHelpFormatter, @@ -55,6 +56,14 @@ if __name__ == "__main__": help="Path of the destination folder. If not defined, the operation is done in place", ) + parser.add_argument( + "--file_ext", + type=str, + default="wav", + required=False, + help="Extension of the audio files to resample", + ) + parser.add_argument( "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores" ) @@ -67,7 +76,7 @@ if __name__ == "__main__": args.input_dir = args.output_dir print("Resampling the audio files...") - audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True) + audio_files = glob.glob(os.path.join(args.input_dir, f"**/*.{args.file_ext}"), recursive=True) print(f"Found {len(audio_files)} files...") audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr])) with Pool(processes=args.n_jobs) as p: diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 8c364300..5828411c 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -8,6 +8,7 @@ import traceback import torch from torch.utils.data import DataLoader +from trainer.torch import NoamLR from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss @@ -19,7 +20,7 @@ from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.io import load_fsspec from TTS.utils.radam import RAdam -from TTS.utils.training import NoamLR, check_update +from TTS.utils.training import check_update torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 0f8c4760..976b74af 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,19 +1,22 @@ import os -import torch +from dataclasses import dataclass, field -from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.audio import AudioProcessor + + +@dataclass +class 
TrainTTSArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) def main(): """Run `tts` model training directly by a `config.json` file.""" # init trainer args - train_args = TrainingArgs() + train_args = TrainTTSArgs() parser = train_args.init_argparse(arg_prefix="") # override trainer args from comman-line args @@ -41,45 +44,15 @@ def main(): config = register_config(config_base.model)() # load training samples - train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True) - - # setup audio processor - ap = AudioProcessor(**config.audio) - - # init speaker manager - if check_config_and_model_args(config, "use_speaker_embedding", True): - speaker_manager = SpeakerManager(data_items=train_samples + eval_samples) - if hasattr(config, "model_args"): - config.model_args.num_speakers = speaker_manager.num_speakers - else: - config.num_speakers = speaker_manager.num_speakers - elif check_config_and_model_args(config, "use_d_vector_file", True): - if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True): - speaker_manager = SpeakerManager( - d_vectors_file_path=config.model_args.d_vector_file, - encoder_model_path=config.model_args.speaker_encoder_model_path, - encoder_config_path=config.model_args.speaker_encoder_config_path, - use_cuda=torch.cuda.is_available(), - ) - else: - speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file")) - config.num_speakers = speaker_manager.num_speakers - if hasattr(config, "model_args"): - config.model_args.num_speakers = speaker_manager.num_speakers - else: - speaker_manager = None - - if check_config_and_model_args(config, "use_language_embedding", True): - language_manager = LanguageManager(config=config) - if hasattr(config, "model_args"): - config.model_args.num_languages = language_manager.num_languages - else: - config.num_languages = language_manager.num_languages - else: - language_manager = None + train_samples, eval_samples = load_tts_samples( + config.datasets, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) # init the model from config - model = setup_model(config, speaker_manager, language_manager) + model = setup_model(config, train_samples + eval_samples) # init the trainer and 🚀 trainer = Trainer( @@ -89,7 +62,6 @@ def main(): model=model, train_samples=train_samples, eval_samples=eval_samples, - training_assets={"audio_processor": ap}, parse_command_line_args=False, ) trainer.fit() diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index cd665f29..32ecd7bd 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,16 +1,23 @@ import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs from TTS.config import load_config, register_config -from TTS.trainer import Trainer, TrainingArgs from TTS.utils.audio import AudioProcessor from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model +@dataclass +class TrainVocoderArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + def main(): """Run `tts` model training directly by a `config.json` file.""" # init trainer args - train_args = TrainingArgs() + train_args = TrainVocoderArgs() parser = train_args.init_argparse(arg_prefix="") # override trainer args from comman-line args diff --git a/TTS/config/shared_configs.py 
b/TTS/config/shared_configs.py index f2bd40ad..6394b264 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass from typing import List from coqpit import Coqpit, check_argument +from trainer import TrainerConfig @dataclass @@ -57,6 +58,12 @@ class BaseAudioConfig(Coqpit): do_amp_to_db_mel (bool, optional): enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + pitch_fmax (float, optional): + Maximum frequency of the F0 frames. Defaults to ```640```. + + pitch_fmin (float, optional): + Minimum frequency of the F0 frames. Defaults to ```0```. + trim_db (int): Silence threshold used for silence trimming. Defaults to 45. @@ -135,6 +142,9 @@ class BaseAudioConfig(Coqpit): spec_gain: int = 20 do_amp_to_db_linear: bool = True do_amp_to_db_mel: bool = True + # f0 params + pitch_fmax: float = 640.0 + pitch_fmin: float = 0.0 # normalization params signal_norm: bool = True min_level_db: int = -100 @@ -228,130 +238,24 @@ class BaseDatasetConfig(Coqpit): @dataclass -class BaseTrainingConfig(Coqpit): - """Base config to define the basic training parameters that are shared - among all the models. +class BaseTrainingConfig(TrainerConfig): + """Base config to define the basic 🐸TTS training parameters that are shared + among all the models. It is based on ```Trainer.TrainingConfig```. Args: model (str): Name of the model that is used in the training. - run_name (str): - Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`. - - run_description (str): - Short description of the experiment. - - epochs (int): - Number training epochs. Defaults to 10000. - - batch_size (int): - Training batch size. - - eval_batch_size (int): - Validation batch size. - - mixed_precision (bool): - Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however - it may also cause numerical unstability in some cases. - - scheduler_after_epoch (bool): - If true, run the scheduler step after each epoch else run it after each model step. - - run_eval (bool): - Enable / Disable evaluation (validation) run. Defaults to True. - - test_delay_epochs (int): - Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful - results, hence waiting for a couple of epochs might save some time. - - print_eval (bool): - Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at - the end of the evaluation. Default to ```False```. - - print_step (int): - Number of steps required to print the next training log. - - log_dashboard (str): "tensorboard" or "wandb" - Set the experiment tracking tool - - plot_step (int): - Number of steps required to log training on Tensorboard. - - model_param_stats (bool): - Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. - Defaults to ```False```. - - project_name (str): - Name of the project. Defaults to config.model - - wandb_entity (str): - Name of W&B entity/team. Enables collaboration across a team or org. - - log_model_step (int): - Number of steps required to log a checkpoint as W&B artifact - - save_step (int):ipt - Number of steps required to save the next checkpoint. - - checkpoint (bool): - Enable / Disable checkpointing. - - keep_all_best (bool): - Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults - to ```False```. 
- - keep_after (int): - Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults - to 10000. - num_loader_workers (int): Number of workers for training time dataloader. num_eval_loader_workers (int): Number of workers for evaluation time dataloader. - - output_path (str): - Path for training output folder, either a local file path or other - URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or - S3 (s3://) paths. The nonexist part of the given path is created - automatically. All training artefacts are saved there. """ model: str = None - run_name: str = "coqui_tts" - run_description: str = "" - # training params - epochs: int = 10000 - batch_size: int = None - eval_batch_size: int = None - mixed_precision: bool = False - scheduler_after_epoch: bool = False - # eval params - run_eval: bool = True - test_delay_epochs: int = 0 - print_eval: bool = False - # logging - dashboard_logger: str = "tensorboard" - print_step: int = 25 - plot_step: int = 100 - model_param_stats: bool = False - project_name: str = None - log_model_step: int = None - wandb_entity: str = None - # checkpointing - save_step: int = 10000 - checkpoint: bool = True - keep_all_best: bool = False - keep_after: int = 10000 # dataloading num_loader_workers: int = 0 num_eval_loader_workers: int = 0 use_noise_augment: bool = False use_language_weighted_sampler: bool = False - - # paths - output_path: str = None - # distributed - distributed_backend: str = "nccl" - distributed_url: str = "tcp://localhost:54321" diff --git a/TTS/model.py b/TTS/model.py index 532d05a6..39cbeabc 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple -import numpy as np import torch from coqpit import Coqpit from torch import nn @@ -9,28 +8,21 @@ from torch import nn # pylint: skip-file -class BaseModel(nn.Module, ABC): - """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this. +class BaseTrainerModel(ABC, nn.Module): + """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.""" - Notes on input/output tensor shapes: - Any input or output tensor of the model must be shaped as + @staticmethod + @abstractmethod + def init_from_config(config: Coqpit): + """Init the model from given config. - - 3D tensors `batch x time x channels` - - 2D tensors `batch x channels` - - 1D tensors `batch x 1` - """ - - def __init__(self, config: Coqpit): - super().__init__() - self._set_model_args(config) - - def _set_model_args(self, config: Coqpit): - """Set model arguments from the config. Override this.""" - pass + Override this depending on your model. + """ + ... @abstractmethod def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: - """Forward pass for the model mainly used in training. + """Forward ... for the model mainly used in training. You can be flexible here and use different number of arguments and argument names since it is intended to be used by `train_step()` without exposing it out of the model. @@ -48,7 +40,7 @@ class BaseModel(nn.Module, ABC): @abstractmethod def inference(self, input: torch.Tensor, aux_input={}) -> Dict: - """Forward pass for inference. + """Forward ... for inference. We don't use `*kwargs` since it is problematic with the TorchScript API. @@ -63,9 +55,25 @@ class BaseModel(nn.Module, ABC): ... 
return outputs_dict + def format_batch(self, batch: Dict) -> Dict: + """Format batch returned by the data loader before sending it to the model. + + If not implemented, model uses the batch as is. + Can be used for data augmentation, feature ectraction, etc. + """ + return batch + + def format_batch_on_device(self, batch: Dict) -> Dict: + """Format batch on device before sending it to the model. + + If not implemented, model uses the batch as is. + Can be used for data augmentation, feature ectraction, etc. + """ + return batch + @abstractmethod def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - """Perform a single training step. Run the model forward pass and compute losses. + """Perform a single training step. Run the model forward ... and compute losses. Args: batch (Dict): Input tensors. @@ -93,11 +101,11 @@ class BaseModel(nn.Module, ABC): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - pass + ... @abstractmethod def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can + """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can call `train_step()` with no changes. Args: @@ -114,36 +122,49 @@ class BaseModel(nn.Module, ABC): def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: """The same as `train_log()`""" - pass + ... @abstractmethod - def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: + def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None: """Load a checkpoint and get ready for training or inference. Args: config (Coqpit): Model configuration. checkpoint_path (str): Path to the model checkpoint file. eval (bool, optional): If true, init model for inference else for training. Defaults to False. + strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. """ ... - def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: - """Setup an return optimizer or optimizers.""" - pass + @staticmethod + @abstractmethod + def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel": + """Init the model from given config. - def get_lr(self) -> Union[float, List[float]]: - """Return learning rate(s). - - Returns: - Union[float, List[float]]: Model's initial learning rates. + Override this depending on your model. """ - pass + ... - def get_scheduler(self, optimizer: torch.optim.Optimizer): - pass + @abstractmethod + def get_data_loader( + self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int + ): + ... - def get_criterion(self): - pass + # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: + # """Setup an return optimizer or optimizers.""" + # ... - def format_batch(self): - pass + # def get_lr(self) -> Union[float, List[float]]: + # """Return learning rate(s). + + # Returns: + # Union[float, List[float]]: Model's initial learning rates. + # """ + # ... + + # def get_scheduler(self, optimizer: torch.optim.Optimizer): + # ... + + # def get_criterion(self): + # ... 
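The rewritten `TTS/model.py` above renames `BaseModel` to `BaseTrainerModel`, adds `format_batch`/`format_batch_on_device` hooks, and comments out the optimizer/scheduler/criterion getters (those now belong to the external `trainer` package). As an illustration only, here is a toy skeleton of the interface a subclass is expected to fill in; the class, its layer and its loss handling are placeholders and not a real 🐸TTS model.

```python
from typing import Dict, List, Tuple

import torch
from coqpit import Coqpit
from torch import nn

from TTS.model import BaseTrainerModel


class ToyModel(BaseTrainerModel):
    """Placeholder model sketching the abstract methods shown in the diff."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(80, 80)  # dummy layer so the module has parameters

    @staticmethod
    def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "ToyModel":
        return ToyModel()

    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
        return {"model_outputs": self.layer(input)}

    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
        return {"model_outputs": self.layer(input)}

    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        outputs = self.forward(batch["mel"])
        loss = criterion(outputs["model_outputs"], batch["mel"])
        return outputs, {"loss": loss}

    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion)

    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
        state = torch.load(checkpoint_path, map_location="cpu")
        self.load_state_dict(state["model"], strict=strict)

    def get_data_loader(
        self, config: Coqpit, assets: Dict, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
    ):
        ...  # build and return a torch DataLoader over `data_items`
```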
diff --git a/TTS/server/server.py b/TTS/server/server.py index f2512582..aef507fd 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path: if args.vocoder_name is not None and not args.vocoder_path: vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) -# CASE3: set custome model paths +# CASE3: set custom model paths if args.model_path is not None: model_path = args.model_path config_path = args.config_path @@ -170,9 +170,9 @@ def tts(): text = request.args.get("text") speaker_idx = request.args.get("speaker_id", "") style_wav = request.args.get("style_wav", "") - style_wav = style_wav_uri_to_dict(style_wav) print(" > Model input: {}".format(text)) + print(" > Speaker Idx: {}".format(speaker_idx)) wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav) out = io.BytesIO() synthesizer.save_wav(wavs, out) diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py index 5b0fee22..28a23e2f 100644 --- a/TTS/speaker_encoder/dataset.py +++ b/TTS/speaker_encoder/dataset.py @@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset): mel = self.ap.melspectrogram(wav).astype("float32") # sample seq_len - assert text.size > 0, self.items[idx][1] - assert wav.size > 0, self.items[idx][1] + assert text.size > 0, self.items[idx]["audio_file"] + assert wav.size > 0, self.items[idx]["audio_file"] sample = { "mel": mel, - "item_idx": self.items[idx][1], + "item_idx": self.items[idx]["audio_file"], "speaker_name": speaker_name, } return sample @@ -91,8 +91,8 @@ class SpeakerEncoderDataset(Dataset): def __parse_items(self): self.speaker_to_utters = {} for i in self.items: - path_ = i[1] - speaker_ = i[2] + path_ = i["audio_file"] + speaker_ = i["speaker_name"] if speaker_ in self.speaker_to_utters.keys(): self.speaker_to_utters[speaker_].append(path_) else: diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index d6c3dad4..a799fc52 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -229,7 +229,7 @@ class ResNetSpeakerEncoder(nn.Module): x = torch.sum(x * w, dim=2) elif self.encoder_type == "ASP": mu = torch.sum(x * w, dim=2) - sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5)) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) x = torch.cat((mu, sg), 1) x = x.view(x.size()[0], -1) diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index b8aa4093..4ab4e923 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -113,7 +113,7 @@ class AugmentWAV(object): def additive_noise(self, noise_type, audio): - clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4) + clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4) noise_list = random.sample( self.noise_list[noise_type], @@ -135,7 +135,7 @@ class AugmentWAV(object): self.additive_noise_config[noise_type]["min_snr_in_db"], self.additive_noise_config[noise_type]["max_num_noises"], ) - noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4) + noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4) noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio if noises_wav is None: @@ -154,7 +154,7 @@ class AugmentWAV(object): rir_file = random.choice(self.rir_files) rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) - rir = rir / np.sqrt(np.sum(rir ** 2)) + rir = rir / 
np.sqrt(np.sum(rir**2)) return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] def apply_one(self, audio): diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index a32f43bd..7c58a232 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -1,19 +1,24 @@ import os +from dataclasses import dataclass, field from coqpit import Coqpit +from trainer import TrainerArgs, get_last_checkpoint +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config -from TTS.trainer import TrainingArgs -from TTS.tts.utils.text.symbols import parse_symbols +from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from TTS.utils.logging import init_dashboard_logger -from TTS.utils.logging.console_logger import ConsoleLogger -from TTS.utils.trainer_utils import get_last_checkpoint + + +@dataclass +class TrainArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) def getarguments(): - train_config = TrainingArgs() + train_config = TrainArgs() parser = train_config.init_argparse(arg_prefix="") return parser @@ -75,13 +80,13 @@ def process_args(args, config=None): used_characters = parse_symbols() new_fields["characters"] = used_characters copy_model_files(config, experiment_path, new_fields) - dashboard_logger = init_dashboard_logger(config) + dashboard_logger = logger_factory(config, experiment_path) c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, dashboard_logger def init_arguments(): - train_config = TrainingArgs() + train_config = TrainArgs() parser = train_config.init_argparse(arg_prefix="") return parser diff --git a/TTS/trainer.py b/TTS/trainer.py deleted file mode 100644 index 7bffb386..00000000 --- a/TTS/trainer.py +++ /dev/null @@ -1,1199 +0,0 @@ -# -*- coding: utf-8 -*- - -import importlib -import multiprocessing -import os -import platform -import sys -import time -import traceback -from argparse import Namespace -from dataclasses import dataclass, field -from inspect import signature -from typing import Callable, Dict, List, Tuple, Union - -import torch -import torch.distributed as dist -from coqpit import Coqpit -from torch import nn -from torch.nn.parallel import DistributedDataParallel as DDP_th -from torch.utils.data import DataLoader - -from TTS.utils.callbacks import TrainerCallback -from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import ( - KeepAverage, - count_parameters, - get_experiment_folder_path, - get_git_branch, - remove_experiment_folder, - set_init_dict, - to_cuda, -) -from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint -from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger -from TTS.utils.trainer_utils import ( - get_last_checkpoint, - get_optimizer, - get_scheduler, - is_apex_available, - setup_torch_training_env, -) - -multiprocessing.set_start_method("fork") - -if platform.system() != "Windows": - # https://github.com/pytorch/pytorch/issues/973 - import resource - - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) - resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) - - -if is_apex_available(): - from apex import amp - - -@dataclass -class 
TrainingArgs(Coqpit): - """Trainer arguments to be defined externally. It helps integrating the `Trainer` with the higher level APIs and - set the values for distributed training.""" - - continue_path: str = field( - default="", - metadata={ - "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder." - }, - ) - restore_path: str = field( - default="", - metadata={ - "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training." - }, - ) - best_path: str = field( - default="", - metadata={ - "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used" - }, - ) - skip_train_epoch: bool = field( - default=False, metadata={"help": "Run only evaluation iteration. Useful for debugging."} - ) - config_path: str = field(default="", metadata={"help": "Path to the configuration file."}) - rank: int = field(default=0, metadata={"help": "Process rank in distributed training."}) - group_id: str = field(default="", metadata={"help": "Process group id in distributed training."}) - use_ddp: bool = field( - default=False, - metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."}, - ) - - -class Trainer: - def __init__( # pylint: disable=dangerous-default-value - self, - args: Union[Coqpit, Namespace], - config: Coqpit, - output_path: str, - c_logger: ConsoleLogger = None, - dashboard_logger: Union[TensorboardLogger, WandbLogger] = None, - model: nn.Module = None, - get_model: Callable = None, - get_data_samples: Callable = None, - train_samples: List = None, - eval_samples: List = None, - cudnn_benchmark: bool = False, - training_assets: Dict = {}, - parse_command_line_args: bool = True, - ) -> None: - """Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models - or easily be customized. - - Notes: - - Supports Automatic Mixed Precision training. If `Apex` is availabe, it automatically picks that, else - it uses PyTorch's native `amp` module. `Apex` may provide more stable training in some cases. - - Args: - - args (Union[Coqpit, Namespace]): Training arguments parsed either from console by `argparse` or `TrainingArgs` - config object. - - config (Coqpit): Model config object. It includes all the values necessary for initializing, training, evaluating - and testing the model. - - output_path (str): Path to the output training folder. All the files are saved under thi path. - - c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default - console logger is used. Defaults to None. - - dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used. - Defaults to None. - - model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` - initializes a model from the provided config. Defaults to None. - - get_model (Callable): - A function that returns a model. It is used to initialize the model when `model` is not provided. - It either takes the config as the only argument or does not take any argument. - Defaults to None - - get_data_samples (Callable): - A function that returns a list of training and evaluation samples. Used if `train_samples` and - `eval_samples` are None. Defaults to None. 
- - train_samples (List): - A list of training samples used by the model's `get_data_loader` to init the `dataset` and the - `data_loader`. Defaults to None. - - eval_samples (List): - A list of evaluation samples used by the model's `get_data_loader` to init the `dataset` and the - `data_loader`. Defaults to None. - - cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable if the model input - length is changing batch to batch along the training. - - training_assets (Dict): - A dictionary of assets to be used at training and passed to the model's ```train_log(), eval_log(), get_data_loader()``` - during training. It can include `AudioProcessor` or/and `Tokenizer`. Defaults to {}. - - parse_command_line_args (bool): - If true, parse command-line arguments and update `TrainingArgs` and model `config` values. Set it - to false if you parse the arguments yourself. Defaults to True. - - Examples: - - Running trainer with HifiGAN model. - - >>> args = TrainingArgs(...) - >>> config = HifiganConfig(...) - >>> model = GANModel(config) - >>> ap = AudioProcessor(**config.audio) - >>> assets = {"audio_processor": ap} - >>> trainer = Trainer(args, config, output_path, model=model, training_assets=assets) - >>> trainer.fit() - - TODO: - - Wrap model for not calling .module in DDP. - - Accumulate gradients b/w batches. - - Deepspeed integration - - Profiler integration. - - Overfitting to a batch. - - TPU training - - NOTE: Consider moving `training_assets` to the model implementation. - """ - - if parse_command_line_args: - # parse command-line arguments for TrainerArgs() - args, coqpit_overrides = self.parse_argv(args) - - # get ready for training and parse command-line arguments for the model config - config = self.init_training(args, coqpit_overrides, config) - - # set the output path - if args.continue_path: - # use the same path as the continuing run - output_path = args.continue_path - else: - # override the output path if it is provided - output_path = config.output_path if output_path is None else output_path - # create a new output folder name - output_path = get_experiment_folder_path(config.output_path, config.run_name) - os.makedirs(output_path, exist_ok=True) - - # copy training assets to the output folder - copy_model_files(config, output_path) - - # init class members - self.args = args - self.config = config - self.output_path = output_path - self.config.output_log_path = output_path - self.training_assets = training_assets - - # setup logging - log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") - self._setup_logger_config(log_file) - time.sleep(1.0) # wait for the logger to be ready - - # set and initialize Pytorch runtime - self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp) - - # init loggers - self.c_logger = ConsoleLogger() if c_logger is None else c_logger - self.dashboard_logger = dashboard_logger - - # only allow dashboard logging for the main process in DDP mode - if self.dashboard_logger is None and args.rank == 0: - self.dashboard_logger = init_dashboard_logger(config) - - if not self.config.log_model_step: - self.config.log_model_step = self.config.save_step - - self.total_steps_done = 0 - self.epochs_done = 0 - self.restore_step = 0 - self.best_loss = float("inf") - self.train_loader = None - self.eval_loader = None - - self.keep_avg_train = None - self.keep_avg_eval = None - - self.use_apex = self._is_apex_available() - self.use_amp_scaler = self.config.mixed_precision and 
self.use_cuda - - # load data samples - if train_samples is None and get_data_samples is None: - raise ValueError("[!] `train_samples` and `get_data_samples` cannot both be None.") - if train_samples is not None: - self.train_samples = train_samples - self.eval_samples = eval_samples - else: - self.train_samples, self.eval_samples = self.run_get_data_samples(config, get_data_samples) - - # init TTS model - if model is None and get_model is None: - raise ValueError("[!] `model` and `get_model` cannot both be None.") - if model is not None: - self.model = model - else: - self.run_get_model(self.config, get_model) - - # setup criterion - self.criterion = self.get_criterion(self.model) - - # DISTRUBUTED - if self.num_gpus > 1: - init_distributed( - args.rank, - self.num_gpus, - args.group_id, - self.config.distributed_backend, - self.config.distributed_url, - ) - - if self.use_cuda: - self.model.cuda() - if isinstance(self.criterion, list): - self.criterion = [x.cuda() for x in self.criterion] - else: - self.criterion.cuda() - - # setup optimizer - self.optimizer = self.get_optimizer(self.model, self.config) - - # CALLBACK - self.callbacks = TrainerCallback() - self.callbacks.on_init_start(self) - - # init AMP - if self.use_amp_scaler: - if self.use_apex: - self.scaler = None - self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1") - # if isinstance(self.optimizer, list): - # self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer) - # else: - self.scaler = torch.cuda.amp.GradScaler() - else: - self.scaler = None - - if self.args.restore_path: - self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model( - self.config, args.restore_path, self.model, self.optimizer, self.scaler - ) - - # setup scheduler - self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer) - - if self.scheduler is not None: - if self.args.continue_path: - if isinstance(self.scheduler, list): - for scheduler in self.scheduler: - if scheduler is not None: - scheduler.last_epoch = self.restore_step - else: - self.scheduler.last_epoch = self.restore_step - - # DISTRIBUTED - if self.num_gpus > 1: - self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) - - # count model size - num_params = count_parameters(self.model) - print("\n > Model has {} parameters".format(num_params)) - - self.callbacks.on_init_end(self) - - @staticmethod - def parse_argv(args: Union[Coqpit, List]): - """Parse command line arguments to init or override `TrainingArgs()`.""" - if isinstance(args, Coqpit): - parser = args.init_argparse(arg_prefix="") - else: - train_config = TrainingArgs() - parser = train_config.init_argparse(arg_prefix="") - training_args, coqpit_overrides = parser.parse_known_args() - args.parse_args(training_args) - return args, coqpit_overrides - - def init_training( - self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None - ): # pylint: disable=no-self-use - """Initialize training and update model configs from command line arguments. - - Args: - args (argparse.Namespace or dict like): Parsed input arguments. - config_overrides (argparse.Namespace or dict like): Parsed config overriding arguments. - config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. - - Returns: - c (TTS.utils.io.AttrDict): Config paramaters. 
- """ - # set arguments for continuing training - if args.continue_path: - experiment_path = args.continue_path - args.config_path = os.path.join(args.continue_path, "config.json") - args.restore_path, best_model = get_last_checkpoint(args.continue_path) - if not args.best_path: - args.best_path = best_model - - # override config values from command-line args - # TODO: Maybe it is better to do it outside - if len(coqpit_overrides) > 0: - config.parse_known_args(coqpit_overrides, arg_prefix="coqpit", relaxed_parser=True) - experiment_path = args.continue_path - - # update the config.json fields and copy it to the output folder - if args.rank == 0: - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - copy_model_files(config, experiment_path, new_fields) - return config - - @staticmethod - def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module: - """Run the `get_model` function and return the model. - - Args: - config (Coqpit): Model config. - - Returns: - nn.Module: initialized model. - """ - if len(signature(get_model).sig.parameters) == 1: - model = get_model(config) - else: - model = get_model() - return model - - @staticmethod - def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module: - if callable(get_data_samples): - if len(signature(get_data_samples).sig.parameters) == 1: - train_samples, eval_samples = get_data_samples(config) - else: - train_samples, eval_samples = get_data_samples() - return train_samples, eval_samples - return None, None - - def restore_model( - self, - config: Coqpit, - restore_path: str, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scaler: torch.cuda.amp.GradScaler = None, - ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: - """Restore training from an old run. It restores model, optimizer, AMP scaler and training stats. - - Args: - config (Coqpit): Model config. - restore_path (str): Path to the restored training run. - model (nn.Module): Model to restored. - optimizer (torch.optim.Optimizer): Optimizer to restore. - scaler (torch.cuda.amp.GradScaler, optional): AMP scaler to restore. Defaults to None. - - Returns: - Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: [description] - """ - - def _restore_list_objs(states, obj): - if isinstance(obj, list): - for idx, state in enumerate(states): - obj[idx].load_state_dict(state) - else: - obj.load_state_dict(states) - return obj - - print(" > Restoring from %s ..." 
% os.path.basename(restore_path)) - checkpoint = load_fsspec(restore_path, map_location="cpu") - try: - print(" > Restoring Model...") - model.load_state_dict(checkpoint["model"]) - print(" > Restoring Optimizer...") - optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer) - if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]: - print(" > Restoring Scaler...") - scaler = _restore_list_objs(checkpoint["scaler"], scaler) - except (KeyError, RuntimeError, ValueError): - print(" > Partial model initialization...") - model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], config) - model.load_state_dict(model_dict) - del model_dict - - if isinstance(self.optimizer, list): - for idx, optim in enumerate(optimizer): - for group in optim.param_groups: - group["lr"] = self.get_lr(model, config)[idx] - else: - for group in optimizer.param_groups: - group["lr"] = self.get_lr(model, config) - print( - " > Model restored from step %d" % checkpoint["step"], - ) - restore_step = checkpoint["step"] - torch.cuda.empty_cache() - return model, optimizer, scaler, restore_step - - ######################### - # DATA LOADING FUNCTIONS - ######################### - - def _get_loader( - self, - model: nn.Module, - config: Coqpit, - assets: Dict, - is_eval: bool, - data_items: List, - verbose: bool, - num_gpus: int, - ) -> DataLoader: - if num_gpus > 1: - if hasattr(model.module, "get_data_loader"): - loader = model.module.get_data_loader( - config, assets, is_eval, data_items, verbose, num_gpus, self.args.rank - ) - else: - if hasattr(model, "get_data_loader"): - loader = model.get_data_loader(config, assets, is_eval, data_items, verbose, num_gpus) - return loader - - def get_train_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader: - """Initialize and return a training data loader. - - Args: - ap (AudioProcessor): Audio processor. - data_items (List): Data samples used for training. - verbose (bool): enable/disable printing loader stats at initialization. - - Returns: - DataLoader: Initialized training data loader. - """ - return self._get_loader(self.model, self.config, training_assets, False, data_items, verbose, self.num_gpus) - - def get_eval_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader: - return self._get_loader(self.model, self.config, training_assets, True, data_items, verbose, self.num_gpus) - - def format_batch(self, batch: List) -> Dict: - """Format the dataloader output and return a batch. - - Args: - batch (List): Batch returned by the dataloader. - - Returns: - Dict: Formatted batch. - """ - if self.num_gpus > 1: - batch = self.model.module.format_batch(batch) - else: - batch = self.model.format_batch(batch) - if self.use_cuda: - for k, v in batch.items(): - batch[k] = to_cuda(v) - return batch - - ###################### - # TRAIN FUNCTIONS - ###################### - - @staticmethod - def master_params(optimizer: torch.optim.Optimizer): - """Generator over parameters owned by the optimizer. - - Used to select parameters used by the optimizer for gradient clipping. - - Args: - optimizer: Target optimizer. - """ - for group in optimizer.param_groups: - for p in group["params"]: - yield p - - @staticmethod - def _model_train_step( - batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None - ) -> Tuple[Dict, Dict]: - """ - Perform a trainig forward step. Compute model outputs and losses. 
- - Args: - batch (Dict): [description] - model (nn.Module): [description] - criterion (nn.Module): [description] - optimizer_idx (int, optional): [description]. Defaults to None. - - Returns: - Tuple[Dict, Dict]: [description] - """ - input_args = [batch, criterion] - if optimizer_idx is not None: - input_args.append(optimizer_idx) - # unwrap model in DDP training - if hasattr(model, "module"): - return model.module.train_step(*input_args) - return model.train_step(*input_args) - - def _optimize( - self, - batch: Dict, - model: nn.Module, - optimizer: Union[torch.optim.Optimizer, List], - scaler: "AMPScaler", - criterion: nn.Module, - scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access - config: Coqpit, - optimizer_idx: int = None, - ) -> Tuple[Dict, Dict, int]: - """Perform a forward - backward pass and run the optimizer. - - Args: - batch (Dict): Input batch. If - model (nn.Module): Model for training. Defaults to None. - optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use. - scaler (AMPScaler): AMP scaler. - criterion (nn.Module): Model's criterion. - scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer. - config (Coqpit): Model config. - optimizer_idx (int, optional): Target optimizer being used. Defaults to None. - - Raises: - RuntimeError: When the loss is NaN. - - Returns: - Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm. - """ - - step_start_time = time.time() - # zero-out optimizer - optimizer.zero_grad() - - # forward pass and loss computation - with torch.cuda.amp.autocast(enabled=config.mixed_precision): - if optimizer_idx is not None: - outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx) - else: - outputs, loss_dict = self._model_train_step(batch, model, criterion) - - # skip the rest - if outputs is None: - step_time = time.time() - step_start_time - return None, {}, step_time - - # # check nan loss - # if torch.isnan(loss_dict["loss"]).any(): - # raise RuntimeError(f" > NaN loss detected - {loss_dict}") - - # set gradient clipping threshold - if "grad_clip" in config and config.grad_clip is not None: - if optimizer_idx is not None: - grad_clip = config.grad_clip[optimizer_idx] - else: - grad_clip = config.grad_clip - else: - grad_clip = 0.0 # meaning no gradient clipping - - # optimizer step - grad_norm = 0 - update_lr_scheduler = True - if self.use_amp_scaler: - if self.use_apex: - # TODO: verify AMP use for GAN training in TTS - # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers - with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss: - scaled_loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), grad_clip) - else: - # model optimizer step in mixed precision mode - scaler.scale(loss_dict["loss"]).backward() - if grad_clip > 0: - scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip) - scale_prev = scaler.get_scale() - scaler.step(optimizer) - scaler.update() - update_lr_scheduler = scale_prev <= scaler.get_scale() - loss_dict["amp_scaler"] = scaler.get_scale() # for logging - else: - # main model optimizer step - loss_dict["loss"].backward() - if grad_clip > 0: - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) - 
optimizer.step() - - # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN - if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)): - grad_norm = 0 - - step_time = time.time() - step_start_time - - # setup lr - if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch: - scheduler.step() - - # detach losses - loss_dict = self._detach_loss_dict(loss_dict) - if optimizer_idx is not None: - loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss") - loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm - else: - loss_dict["grad_norm"] = grad_norm - return outputs, loss_dict, step_time - - def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]: - """Perform a training step on a batch of inputs and log the process. - - Args: - batch (Dict): Input batch. - batch_n_steps (int): Number of steps needed to complete an epoch. Needed for logging. - step (int): Current step number in this epoch. - loader_start_time (float): The time when the data loading is started. Needed for logging. - - Returns: - Tuple[Dict, Dict]: Model outputs and losses. - """ - self.callbacks.on_train_step_start(self) - # format data - batch = self.format_batch(batch) - loader_time = time.time() - loader_start_time - - # conteainers to hold model outputs and losses for each optimizer. - outputs_per_optimizer = None - loss_dict = {} - if not isinstance(self.optimizer, list): - # training with a single optimizer - outputs, loss_dict_new, step_time = self._optimize( - batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config - ) - loss_dict.update(loss_dict_new) - else: - # training with multiple optimizers (e.g. 
GAN) - outputs_per_optimizer = [None] * len(self.optimizer) - total_step_time = 0 - for idx, optimizer in enumerate(self.optimizer): - criterion = self.criterion - # scaler = self.scaler[idx] if self.use_amp_scaler else None - scaler = self.scaler - scheduler = self.scheduler[idx] - outputs, loss_dict_new, step_time = self._optimize( - batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx - ) - # skip the rest if the model returns None - total_step_time += step_time - outputs_per_optimizer[idx] = outputs - # merge loss_dicts from each optimizer - # rename duplicates with the optimizer idx - # if None, model skipped this optimizer - if loss_dict_new is not None: - for k, v in loss_dict_new.items(): - if k in loss_dict: - loss_dict[f"{k}-{idx}"] = v - else: - loss_dict[k] = v - step_time = total_step_time - outputs = outputs_per_optimizer - - # update avg runtime stats - keep_avg_update = {} - keep_avg_update["avg_loader_time"] = loader_time - keep_avg_update["avg_step_time"] = step_time - self.keep_avg_train.update_values(keep_avg_update) - - # update avg loss stats - update_eval_values = {} - for key, value in loss_dict.items(): - update_eval_values["avg_" + key] = value - self.keep_avg_train.update_values(update_eval_values) - - # print training progress - if self.total_steps_done % self.config.print_step == 0: - # log learning rates - lrs = {} - if isinstance(self.optimizer, list): - for idx, optimizer in enumerate(self.optimizer): - current_lr = self.optimizer[idx].param_groups[0]["lr"] - lrs.update({f"current_lr_{idx}": current_lr}) - else: - current_lr = self.optimizer.param_groups[0]["lr"] - lrs = {"current_lr": current_lr} - - # log run-time stats - loss_dict.update(lrs) - loss_dict.update( - { - "step_time": round(step_time, 4), - "loader_time": round(loader_time, 4), - } - ) - self.c_logger.print_train_step( - batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values - ) - - if self.args.rank == 0: - # Plot Training Iter Stats - # reduce TB load and don't log every step - if self.total_steps_done % self.config.plot_step == 0: - self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict) - if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0: - if self.config.checkpoint: - # checkpoint the model - target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train) - save_checkpoint( - self.config, - self.model, - self.optimizer, - self.scaler if self.use_amp_scaler else None, - self.total_steps_done, - self.epochs_done, - self.output_path, - model_loss=target_avg_loss, - ) - - if self.total_steps_done % self.config.log_model_step == 0: - # log checkpoint as artifact - aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"] - self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases) - - # training visualizations - if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"): - self.model.module.train_log( - batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done - ) - elif hasattr(self.model, "train_log"): - self.model.train_log( - batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done - ) - - self.dashboard_logger.flush() - - self.total_steps_done += 1 - self.callbacks.on_train_step_end(self) - return outputs, loss_dict - - def train_epoch(self) -> None: - """Main entry point for the training loop. 
Run training on the all training samples.""" - # initialize the data loader - self.train_loader = self.get_train_dataloader( - self.training_assets, - self.train_samples, - verbose=True, - ) - # set model to training mode - if self.num_gpus > 1: - self.model.module.train() - else: - self.model.train() - epoch_start_time = time.time() - if self.use_cuda: - batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus)) - else: - batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size) - self.c_logger.print_train_start() - loader_start_time = time.time() - # iterate over the training samples - for cur_step, batch in enumerate(self.train_loader): - _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time) - loader_start_time = time.time() - epoch_time = time.time() - epoch_start_time - # plot self.epochs_done Stats - if self.args.rank == 0: - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(self.keep_avg_train.avg_values) - self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats) - if self.config.model_param_stats: - self.logger.model_weights(self.model, self.total_steps_done) - # scheduler step after the epoch - if self.scheduler is not None and self.config.scheduler_after_epoch: - if isinstance(self.scheduler, list): - for scheduler in self.scheduler: - if scheduler is not None: - scheduler.step() - else: - self.scheduler.step() - - ####################### - # EVAL FUNCTIONS - ####################### - - @staticmethod - def _model_eval_step( - batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None - ) -> Tuple[Dict, Dict]: - """ - Perform a evaluation forward pass. Compute model outputs and losses with no gradients. - - Args: - batch (Dict): IBatch of inputs. - model (nn.Module): Model to call evaluation. - criterion (nn.Module): Model criterion. - optimizer_idx (int, optional): Optimizer ID to define the closure in multi-optimizer training. Defaults to None. - - Returns: - Tuple[Dict, Dict]: model outputs and losses. - """ - input_args = [batch, criterion] - if optimizer_idx is not None: - input_args.append(optimizer_idx) - if hasattr(model, "module"): - return model.module.eval_step(*input_args) - return model.eval_step(*input_args) - - def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]: - """Perform a evaluation step on a batch of inputs and log the process. - - Args: - batch (Dict): Input batch. - step (int): Current step number in this epoch. - - Returns: - Tuple[Dict, Dict]: Model outputs and losses. 
- """ - with torch.no_grad(): - outputs = [] - loss_dict = {} - if not isinstance(self.optimizer, list): - outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion) - else: - outputs = [None] * len(self.optimizer) - for idx, _ in enumerate(self.optimizer): - criterion = self.criterion - outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx) - outputs[idx] = outputs_ - - if loss_dict_new is not None: - loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss") - loss_dict.update(loss_dict_new) - - loss_dict = self._detach_loss_dict(loss_dict) - - # update avg stats - update_eval_values = {} - for key, value in loss_dict.items(): - update_eval_values["avg_" + key] = value - self.keep_avg_eval.update_values(update_eval_values) - - if self.config.print_eval: - self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values) - return outputs, loss_dict - - def eval_epoch(self) -> None: - """Main entry point for the evaluation loop. Run evaluation on the all validation samples.""" - self.eval_loader = ( - self.get_eval_dataloader( - self.training_assets, - self.eval_samples, - verbose=True, - ) - if self.config.run_eval - else None - ) - - self.model.eval() - self.c_logger.print_eval_start() - loader_start_time = time.time() - batch = None - for cur_step, batch in enumerate(self.eval_loader): - # format data - batch = self.format_batch(batch) - loader_time = time.time() - loader_start_time - self.keep_avg_eval.update_values({"avg_loader_time": loader_time}) - outputs, _ = self.eval_step(batch, cur_step) - loader_start_time = time.time() - # plot epoch stats, artifacts and figures - if self.args.rank == 0: - if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"): - self.model.module.eval_log( - batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done - ) - elif hasattr(self.model, "eval_log"): - self.model.eval_log(batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done) - self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values) - - def test_run(self) -> None: - """Run test and log the results. Test run must be defined by the model. 
- Model must return figures and audios to be logged by the Tensorboard.""" - if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")): - if self.eval_loader is None: - self.eval_loader = self.get_eval_dataloader( - self.training_assets, - self.eval_samples, - verbose=True, - ) - - if hasattr(self.eval_loader.dataset, "load_test_samples"): - samples = self.eval_loader.dataset.load_test_samples(1) - if self.num_gpus > 1: - figures, audios = self.model.module.test_run(self.training_assets, samples, None) - else: - figures, audios = self.model.test_run(self.training_assets, samples, None) - else: - if self.num_gpus > 1: - figures, audios = self.model.module.test_run(self.training_assets) - else: - figures, audios = self.model.test_run(self.training_assets) - self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) - self.dashboard_logger.test_figures(self.total_steps_done, figures) - - def _restore_best_loss(self): - """Restore the best loss from the args.best_path if provided else - from the model (`args.restore_path` or `args.continue_path`) used for resuming the training""" - if self.restore_step != 0 or self.args.best_path: - print(f" > Restoring best loss from {os.path.basename(self.args.best_path)} ...") - ch = load_fsspec(self.args.restore_path, map_location="cpu") - if "model_loss" in ch: - self.best_loss = ch["model_loss"] - print(f" > Starting with loaded last best loss {self.best_loss}.") - - ################################### - # FIT FUNCTIONS - ################################### - - def _fit(self) -> None: - """🏃 train -> evaluate -> test for the number of epochs.""" - self._restore_best_loss() - - self.total_steps_done = self.restore_step - - for epoch in range(0, self.config.epochs): - if self.num_gpus > 1: - # let all processes sync up before starting with a new epoch of training - dist.barrier() - self.callbacks.on_epoch_start(self) - self.keep_avg_train = KeepAverage() - self.keep_avg_eval = KeepAverage() if self.config.run_eval else None - self.epochs_done = epoch - self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path) - if not self.args.skip_train_epoch: - self.train_epoch() - if self.config.run_eval: - self.eval_epoch() - if epoch >= self.config.test_delay_epochs and self.args.rank <= 0: - self.test_run() - self.c_logger.print_epoch_end( - epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values - ) - if self.args.rank in [None, 0]: - self.save_best_model() - self.callbacks.on_epoch_end(self) - - def fit(self) -> None: - """Where the ✨️magic✨️ happens...""" - try: - self._fit() - if self.args.rank == 0: - self.dashboard_logger.finish() - except KeyboardInterrupt: - self.callbacks.on_keyboard_interrupt(self) - # if the output folder is empty remove the run. - remove_experiment_folder(self.output_path) - # clear the DDP processes - if self.num_gpus > 1: - dist.destroy_process_group() - # finish the wandb run and sync data - if self.args.rank == 0: - self.dashboard_logger.finish() - # stop without error signal - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except BaseException: # pylint: disable=broad-except - remove_experiment_folder(self.output_path) - traceback.print_exc() - sys.exit(1) - - def save_best_model(self) -> None: - """Save the best model. 
It only saves if the current target loss is smaller then the previous.""" - - # set the target loss to choose the best model - target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train) - - # save the model and update the best_loss - self.best_loss = save_best_model( - target_loss_dict, - self.best_loss, - self.config, - self.model, - self.optimizer, - self.scaler if self.use_amp_scaler else None, - self.total_steps_done, - self.epochs_done, - self.output_path, - keep_all_best=self.config.keep_all_best, - keep_after=self.config.keep_after, - ) - - ##################### - # GET FUNCTIONS - ##################### - - @staticmethod - def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]: - """Receive the optimizer from the model if model implements `get_optimizer()` else - check the optimizer parameters in the config and try initiating the optimizer. - - Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. - - Returns: - Union[torch.optim.Optimizer, List]: A optimizer or a list of optimizers. GAN models define a list. - """ - if hasattr(model, "get_optimizer"): - optimizer = model.get_optimizer() - if optimizer is None: - optimizer_name = config.optimizer - optimizer_params = config.optimizer_params - return get_optimizer(optimizer_name, optimizer_params, config.lr, model) - return optimizer - - @staticmethod - def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]: - """Set the initial learning rate by the model if model implements `get_lr()` else try setting the learning rate - fromthe config. - - Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. - - Returns: - Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimzier. - """ - lr = None - if hasattr(model, "get_lr"): - lr = model.get_lr() - if lr is None: - lr = config.lr - return lr - - @staticmethod - def get_scheduler( - model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List] - ) -> Union[torch.optim.lr_scheduler._LRScheduler, List]: # pylint: disable=protected-access - """Receive the scheduler from the model if model implements `get_scheduler()` else - check the config and try initiating the scheduler. - - Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. - - Returns: - Union[torch.optim.Optimizer, List]: A scheduler or a list of schedulers, one for each optimizer. - """ - scheduler = None - if hasattr(model, "get_scheduler"): - scheduler = model.get_scheduler(optimizer) - if scheduler is None: - lr_scheduler = config.lr_scheduler - lr_scheduler_params = config.lr_scheduler_params - return get_scheduler(lr_scheduler, lr_scheduler_params, optimizer) - return scheduler - - @staticmethod - def get_criterion(model: nn.Module) -> nn.Module: - """Receive the criterion from the model. Model must implement `get_criterion()`. - - Args: - model (nn.Module): Training model. - - Returns: - nn.Module: Criterion layer. - """ - criterion = None - criterion = model.get_criterion() - return criterion - - #################### - # HELPER FUNCTIONS - #################### - - @staticmethod - def _detach_loss_dict(loss_dict: Dict) -> Dict: - """Detach loss values from autograp. - - Args: - loss_dict (Dict): losses. - - Returns: - Dict: losses detached from autograph. 
- """ - loss_dict_detached = {} - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_detached[key] = value - else: - loss_dict_detached[key] = value.detach().item() - return loss_dict_detached - - def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict: - """Pick the target loss to compare models""" - target_avg_loss = None - - # return if target loss defined in the model config - if "target_loss" in self.config and self.config.target_loss: - return keep_avg_target[f"avg_{self.config.target_loss}"] - - # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers - if isinstance(self.optimizer, list): - target_avg_loss = 0 - for idx in range(len(self.optimizer)): - target_avg_loss += keep_avg_target[f"avg_loss_{idx}"] - target_avg_loss /= len(self.optimizer) - else: - target_avg_loss = keep_avg_target["avg_loss"] - return target_avg_loss - - def _setup_logger_config(self, log_file: str) -> None: - """Write log strings to a file and print logs to the terminal. - TODO: Causes formatting issues in pdb debugging.""" - - class Logger(object): - def __init__(self, print_to_terminal=True): - self.print_to_terminal = print_to_terminal - self.terminal = sys.stdout - self.log_file = log_file - - def write(self, message): - if self.print_to_terminal: - self.terminal.write(message) - with open(self.log_file, "a", encoding="utf-8") as f: - f.write(message) - - def flush(self): - # this flush method is needed for python 3 compatibility. - # this handles the flush command by doing nothing. - # you might want to specify some extra behavior here. - pass - - # don't let processes rank > 0 write to the terminal - sys.stdout = Logger(self.args.rank == 0) - - @staticmethod - def _is_apex_available() -> bool: - """Check if Nvidia's APEX is available.""" - return importlib.util.find_spec("apex") is not None diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 8f063102..024040f8 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -89,11 +89,11 @@ class FastPitchConfig(BaseTTSConfig): pitch_loss_alpha (float): Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. - binary_loss_alpha (float): + binary_align_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. - binary_align_loss_start_step (int): - Start binary alignment loss after this many steps. Defaults to 20000. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. min_seq_len (int): Minimum input sequence length to be used at training. 
@@ -129,12 +129,12 @@ class FastPitchConfig(BaseTTSConfig): duration_loss_type: str = "mse" use_ssim_loss: bool = True ssim_loss_alpha: float = 1.0 - dur_loss_alpha: float = 1.0 spec_loss_alpha: float = 1.0 - pitch_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0 - binary_align_loss_alpha: float = 1.0 - binary_align_loss_start_step: int = 20000 + pitch_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py index 31d99442..16a76e21 100644 --- a/TTS/tts/configs/fast_speech_config.py +++ b/TTS/tts/configs/fast_speech_config.py @@ -93,8 +93,8 @@ class FastSpeechConfig(BaseTTSConfig): binary_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. - binary_align_loss_start_step (int): - Start binary alignment loss after this many steps. Defaults to 20000. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. min_seq_len (int): Minimum input sequence length to be used at training. @@ -135,7 +135,7 @@ class FastSpeechConfig(BaseTTSConfig): pitch_loss_alpha: float = 0.0 aligner_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 1.0 - binary_align_loss_start_step: int = 20000 + binary_loss_warmup_epochs: int = 150 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/configs/glow_tts_config.py b/TTS/tts/configs/glow_tts_config.py index ce8eee6d..f42f3e5a 100644 --- a/TTS/tts/configs/glow_tts_config.py +++ b/TTS/tts/configs/glow_tts_config.py @@ -153,6 +153,7 @@ class GlowTTSConfig(BaseTTSConfig): # multi-speaker settings use_speaker_embedding: bool = False + speakers_file: str = None use_d_vector_file: bool = False d_vector_file: str = False diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 60ef7276..f43c6464 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass, field -from typing import List +from typing import Dict, List from coqpit import Coqpit, check_argument @@ -50,9 +50,16 @@ class GSTConfig(Coqpit): @dataclass class CharactersConfig(Coqpit): - """Defines character or phoneme set used by the model + """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses. Args: + characters_class (str): + Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on + the configuration. Defaults to None. + + vocab_dict (dict): + Defines the vocabulary dictionary used to encode the characters. Defaults to None. + pad (str): characters in place of empty padding. Defaults to None. @@ -62,6 +69,9 @@ class CharactersConfig(Coqpit): bos (str): characters showing the beginning of a sentence. Defaults to None. + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + characters (str): character set used by the model. Characters not in this list are ignored when converting input text to a list of sequence IDs. Defaults to None. @@ -70,32 +80,32 @@ class CharactersConfig(Coqpit): characters considered as punctuation as parsing the input sentence. Defaults to None. phonemes (str): - characters considered as parsing phonemes. Defaults to None. + characters considered as parsing phonemes. This is only for backwards compat. 
Use `characters` for new + models. Defaults to None. - unique (bool): + is_unique (bool): remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old - models trained with character lists with duplicates. + models trained with character lists with duplicates. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. """ + characters_class: str = None + + # using BaseVocabulary + vocab_dict: Dict = None + + # using on BaseCharacters pad: str = None eos: str = None bos: str = None + blank: str = None characters: str = None punctuations: str = None phonemes: str = None - unique: bool = True # for backwards compatibility of models trained with char sets with duplicates - - def check_values( - self, - ): - """Check config fields""" - c = asdict(self) - check_argument("pad", c, prerequest="characters", restricted=True) - check_argument("eos", c, prerequest="characters", restricted=True) - check_argument("bos", c, prerequest="characters", restricted=True) - check_argument("characters", c, prerequest="characters", restricted=True) - check_argument("phonemes", c, restricted=True) - check_argument("punctuations", c, prerequest="characters", restricted=True) + is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates + is_sorted: bool = True @dataclass @@ -110,8 +120,13 @@ class BaseTTSConfig(BaseTrainingConfig): use_phonemes (bool): enable / disable phoneme use. - use_espeak_phonemes (bool): - enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`). + phonemizer (str): + Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`. + Defaults to None. + + phoneme_language (str): + Language code for the phonemizer. You can check the list of supported languages by running + `python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None. compute_input_seq_cache (bool): enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of @@ -144,11 +159,19 @@ class BaseTTSConfig(BaseTrainingConfig): sort_by_audio_len (bool): If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`. - min_seq_len (int): - Minimum sequence length to be used at training. + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. - max_seq_len (int): - Maximum sequence length to be used at training. Larger values result in more VRAM usage. + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). compute_f0 (int): (Not in use yet). @@ -156,9 +179,16 @@ class BaseTTSConfig(BaseTrainingConfig): compute_linear_spec (bool): If True data loader computes and returns linear spectrograms alongside the other data. + precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + use_noise_augment (bool): Augment the input audio with random noise. 
+ start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. + add_blank (bool): Add blank characters between each other two characters. It improves performance for some models at expense of slower run-time due to the longer input sequence. @@ -183,12 +213,19 @@ class BaseTTSConfig(BaseTrainingConfig): test_sentences (List[str]): List of sentences to be used at testing. Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). """ audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) # phoneme settings use_phonemes: bool = False - use_espeak_phonemes: bool = True + phonemizer: str = None phoneme_language: str = None compute_input_seq_cache: bool = False text_cleaner: str = None @@ -197,17 +234,21 @@ class BaseTTSConfig(BaseTrainingConfig): phoneme_cache_path: str = None # vocabulary parameters characters: CharactersConfig = None + add_blank: bool = False # training params batch_group_size: int = 0 loss_masking: bool = None # dataloading sort_by_audio_len: bool = False - min_seq_len: int = 1 - max_seq_len: int = float("inf") + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") compute_f0: bool = False compute_linear_spec: bool = False + precompute_num_workers: int = 0 use_noise_augment: bool = False - add_blank: bool = False + start_by_longest: bool = False # dataset datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # optimizer @@ -218,3 +259,6 @@ class BaseTTSConfig(BaseTrainingConfig): lr_scheduler_params: dict = field(default_factory=lambda: {}) # testing test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py index ea6866ed..4bf5101f 100644 --- a/TTS/tts/configs/speedy_speech_config.py +++ b/TTS/tts/configs/speedy_speech_config.py @@ -89,8 +89,8 @@ class SpeedySpeechConfig(BaseTTSConfig): binary_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. - binary_align_loss_start_step (int): - Start binary alignment loss after this many steps. Defaults to 20000. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. min_seq_len (int): Minimum input sequence length to be used at training. @@ -150,7 +150,7 @@ class SpeedySpeechConfig(BaseTTSConfig): spec_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 0.3 - binary_align_loss_start_step: int = 50000 + binary_loss_warmup_epochs: int = 150 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py index d6edd267..5193c224 100644 --- a/TTS/tts/configs/tacotron_config.py +++ b/TTS/tts/configs/tacotron_config.py @@ -83,6 +83,8 @@ class TacotronConfig(BaseTTSConfig): ddc_r (int): reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this as a multiple of the `r` value. Defaults to 6. 
+ speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. use_speaker_embedding (bool): enable / disable using speaker embeddings for multi-speaker models. If set True, the model is in the multi-speaker mode. Defaults to False. @@ -176,6 +178,7 @@ class TacotronConfig(BaseTTSConfig): ddc_r: int = 6 # multi-speaker settings + speakers_file: str = None use_speaker_embedding: bool = False speaker_embedding_dim: int = 512 use_d_vector_file: bool = False diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index 36c948af..a8c7f91d 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -17,7 +17,7 @@ class VitsConfig(BaseTTSConfig): Model architecture arguments. Defaults to `VitsArgs()`. grad_clip (List): - Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`. + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. lr_gen (float): Initial learning rate for the generator. Defaults to 0.0002. @@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec (bool): If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. - sort_by_audio_len (bool): - If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`. - - min_seq_len (int): - Minimum sequnce length to be considered for training. Defaults to `0`. - - max_seq_len (int): - Maximum sequnce length to be considered for training. Defaults to `500000`. - r (int): Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. @@ -130,9 +121,6 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec: bool = True # overrides - sort_by_audio_len: bool = True - min_seq_len: int = 0 - max_seq_len: int = 500000 r: int = 1 # DO NOT CHANGE add_blank: bool = True diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 40eed7e3..6c7c9edd 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -9,25 +9,48 @@ from TTS.tts.datasets.dataset import * from TTS.tts.datasets.formatters import * -def split_dataset(items): +def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. - Args: - items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + Args: + <<<<<<< HEAD + items (List[List]): + A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + ======= + items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`. + >>>>>>> Fix docstring """ - speakers = [item[-1] for item in items] + speakers = [item["speaker_name"] for item in items] is_multi_speaker = len(set(speakers)) > 1 - eval_split_size = min(500, int(len(items) * 0.01)) - assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples." 
+ if eval_split_size > 1: + eval_split_size = int(eval_split_size) + else: + if eval_split_max_size: + eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size)) + else: + eval_split_size = int(len(items) * eval_split_size) + + assert ( + eval_split_size > 0 + ), " [!] You do not have enough samples for the evaluation set. You can work around this by setting the 'eval_split_size' parameter to a minimum of {}".format( + 1 / len(items) + ) np.random.seed(0) np.random.shuffle(items) if is_multi_speaker: items_eval = [] - speakers = [item[-1] for item in items] + speakers = [item["speaker_name"] for item in items] speaker_counter = Counter(speakers) while len(items_eval) < eval_split_size: item_idx = np.random.randint(0, len(items)) - speaker_to_be_removed = items[item_idx][-1] + speaker_to_be_removed = items[item_idx]["speaker_name"] if speaker_counter[speaker_to_be_removed] > 1: items_eval.append(items[item_idx]) speaker_counter[speaker_to_be_removed] -= 1 @@ -37,7 +60,11 @@ def split_dataset(items): def load_tts_samples( - datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None + datasets: Union[List[Dict], Dict], + eval_split=True, + formatter: Callable = None, + eval_split_max_size=None, + eval_split_size=0.01, ) -> Tuple[List[List], List[List]]: """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based @@ -52,9 +79,16 @@ def load_tts_samples( formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It must take the root_path and the meta_file name and return a list of samples in the format of - `[[audio_path, text, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as + `[{"text": ..., "audio_file": ..., "speaker_name": ...}, ...]`. See the available formatters in `TTS.tts.dataset.formatter` as example. Defaults to None. + eval_split_max_size (int): + Maximum number of samples to be used for evaluation when splitting by proportion. Defaults to None (disabled). + + eval_split_size (float): + If between 0.0 and 1.0, represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + Returns: Tuple[List[List], List[List]: training and evaluation splits of the dataset.
""" @@ -75,28 +109,28 @@ def load_tts_samples( formatter = _get_formatter_by_name(name) # load train set meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) - meta_data_train = [[*item, language] for item in meta_data_train] + meta_data_train = [{**item, **{"language": language}} for item in meta_data_train] print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") # load evaluation split if set if eval_split: if meta_file_val: meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) - meta_data_eval = [[*item, language] for item in meta_data_eval] + meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval] else: - meta_data_eval, meta_data_train = split_dataset(meta_data_train) + meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size) meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train # load attention masks for the duration predictor training if dataset.meta_file_attn_mask: meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): - attn_file = meta_data[ins[1]].strip() - meta_data_train_all[idx].append(attn_file) + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_train_all[idx].update({"alignment_file": attn_file}) if meta_data_eval_all: for idx, ins in enumerate(meta_data_eval_all): - attn_file = meta_data[ins[1]].strip() - meta_data_eval_all[idx].append(attn_file) + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_eval_all[idx].update({"alignment_file": attn_file}) # set none for the next iter formatter = None return meta_data_train_all, meta_data_eval_all diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 2f20c865..d8f16e4e 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -1,8 +1,7 @@ import collections import os import random -from multiprocessing import Pool -from typing import Dict, List +from typing import Dict, List, Union import numpy as np import torch @@ -10,87 +9,99 @@ import tqdm from torch.utils.data import Dataset from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor -from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence from TTS.utils.audio import AudioProcessor +# to prevent too many open files error as suggested here +# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +torch.multiprocessing.set_sharing_strategy("file_system") + + +def _parse_sample(item): + language_name = None + attn_file = None + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item + elif len(item) == 4: + text, wav_file, speaker_name, language_name = item + elif len(item) == 3: + text, wav_file, speaker_name = item + else: + raise ValueError(" [!] 
Dataset cannot parse the sample.") + return text, wav_file, speaker_name, language_name, attn_file + + +def noise_augment_audio(wav): + return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + class TTSDataset(Dataset): def __init__( self, - outputs_per_step: int, - text_cleaner: list, - compute_linear_spec: bool, - ap: AudioProcessor, - meta_data: List[List], + outputs_per_step: int = 1, + compute_linear_spec: bool = False, + ap: AudioProcessor = None, + samples: List[Dict] = None, + tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, f0_cache_path: str = None, - characters: Dict = None, - custom_symbols: List = None, - add_blank: bool = False, return_wav: bool = False, batch_group_size: int = 0, - min_seq_len: int = 0, - max_seq_len: int = float("inf"), - use_phonemes: bool = False, + min_text_len: int = 0, + max_text_len: int = float("inf"), + min_audio_len: int = 0, + max_audio_len: int = float("inf"), phoneme_cache_path: str = None, - phoneme_language: str = "en-us", - enable_eos_bos: bool = False, + precompute_num_workers: int = 0, speaker_id_mapping: Dict = None, d_vector_mapping: Dict = None, language_id_mapping: Dict = None, use_noise_augment: bool = False, + start_by_longest: bool = False, verbose: bool = False, ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. - If you need something different, you can inherit and override. + If you need something different, you can subclass and override. Args: outputs_per_step (int): Number of time frames predicted per step. - text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs. - compute_linear_spec (bool): compute linear spectrogram if True. ap (TTS.tts.utils.AudioProcessor): Audio processor object. - meta_data (list): List of dataset instances. + samples (list): List of dataset samples. + + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else + use the given. Defaults to None. compute_f0 (bool): compute f0 if True. Defaults to False. f0_cache_path (str): Path to store f0 cache. Defaults to None. - characters (dict): `dict` of custom text characters used for converting texts to sequences. - - custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own - set of symbols need to pass it here. Defaults to `None`. - - add_blank (bool): Add a special `blank` character after every other character. It helps some - models achieve better results. Defaults to false. - return_wav (bool): Return the waveform of the sample. Defaults to False. batch_group_size (int): Range of batch randomization after sorting sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a batch. Set 0 to disable. Defaults to 0. - min_seq_len (int): Minimum input sequence length to be processed - by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a - minimum input length due to its architecture. Defaults to 0. + min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. + Defaults to 0. - max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this. - It helps for controlling the VRAM usage against long input sequences. Especially models with - RNN layers are sensitive to input length. Defaults to `Inf`. + max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored. 
+ Defaults to float("inf"). - use_phonemes (bool): If true, input text converted to phonemes. Defaults to false. + min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored. + Defaults to 0. + + max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored. + The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to + this value if you encounter an OOM error in training. Defaults to float("inf"). phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a separate file. Defaults to None. - phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`. - - enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults - to False. + precompute_num_workers (int): Number of workers to precompute features. Defaults to 0. speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the embedding layer. Defaults to None. @@ -99,285 +110,254 @@ class TTSDataset(Dataset): use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. + start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + verbose (bool): Print diagnostic information. Defaults to false. """ super().__init__() self.batch_group_size = batch_group_size - self.items = meta_data + self._samples = samples self.outputs_per_step = outputs_per_step - self.sample_rate = ap.sample_rate - self.cleaners = text_cleaner self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav self.compute_f0 = compute_f0 self.f0_cache_path = f0_cache_path - self.min_seq_len = min_seq_len - self.max_seq_len = max_seq_len + self.min_audio_len = min_audio_len + self.max_audio_len = max_audio_len + self.min_text_len = min_text_len + self.max_text_len = max_text_len self.ap = ap - self.characters = characters - self.custom_symbols = custom_symbols - self.add_blank = add_blank - self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path - self.phoneme_language = phoneme_language - self.enable_eos_bos = enable_eos_bos self.speaker_id_mapping = speaker_id_mapping self.d_vector_mapping = d_vector_mapping self.language_id_mapping = language_id_mapping self.use_noise_augment = use_noise_augment + self.start_by_longest = start_by_longest self.verbose = verbose - self.input_seq_computed = False self.rescue_item_idx = 1 self.pitch_computed = False + self.tokenizer = tokenizer + + if self.tokenizer.use_phonemes: + self.phoneme_dataset = PhonemeDataset( + self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers + ) - if use_phonemes and not os.path.isdir(phoneme_cache_path): - os.makedirs(phoneme_cache_path, exist_ok=True) if compute_f0: - self.pitch_extractor = PitchExtractor(self.items, verbose=verbose) + self.f0_dataset = F0Dataset( + self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + ) + if self.verbose: - print("\n > DataLoader initialization") - print(" | > Use phonemes: {}".format(self.use_phonemes)) - if use_phonemes: - print(" | > phoneme language: {}".format(phoneme_language)) - print(" | > Number of instances : {}".format(len(self.items))) + self.print_logs() + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = 
os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + @property + def samples(self): + return self._samples + + @samples.setter + def samples(self, new_samples): + self._samples = new_samples + if hasattr(self, "f0_dataset"): + self.f0_dataset.samples = new_samples + if hasattr(self, "phoneme_dataset"): + self.phoneme_dataset.samples = new_samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.load_data(idx) + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> DataLoader initialization") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") def load_wav(self, filename): - audio = self.ap.load_wav(filename) - return audio + waveform = self.ap.load_wav(filename) + assert waveform.size > 0 + return waveform + + def get_phonemes(self, idx, text): + out_dict = self.phoneme_dataset[idx] + assert text == out_dict["text"], f"{text} != {out_dict['text']}" + assert len(out_dict["token_ids"]) > 0 + return out_dict + + def get_f0(self, idx): + out_dict = self.f0_dataset[idx] + item = self.samples[idx] + assert item["audio_file"] == out_dict["audio_file"] + return out_dict @staticmethod - def load_np(filename): - data = np.load(filename).astype("float32") - return data + def get_attn_mask(attn_file): + return np.load(attn_file) - @staticmethod - def _generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, custom_symbols, characters, add_blank - ): - """generate a phoneme sequence from text. - since the usage is for subsequent caching, we never add bos and - eos chars here. Instead we add those dynamically later; based on the - config option.""" - phonemes = phoneme_to_sequence( - text, - [cleaners], - language=language, - enable_eos_bos=False, - custom_symbols=custom_symbols, - tp=characters, - add_blank=add_blank, - ) - phonemes = np.asarray(phonemes, dtype=np.int32) - np.save(cache_path, phonemes) - return phonemes - - @staticmethod - def _load_or_generate_phoneme_sequence( - wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank - ): - file_name = os.path.splitext(os.path.basename(wav_file))[0] - - # different names for normal phonemes and with blank chars. - file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy" - cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext) - try: - phonemes = np.load(cache_path) - except FileNotFoundError: - phonemes = TTSDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, custom_symbols, characters, add_blank - ) - except (ValueError, IOError): - print(" [!] failed loading phonemes for {}. 
" "Recomputing.".format(wav_file)) - phonemes = TTSDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, custom_symbols, characters, add_blank - ) - if enable_eos_bos: - phonemes = pad_with_eos_bos(phonemes, tp=characters) - phonemes = np.asarray(phonemes, dtype=np.int32) - return phonemes + def get_token_ids(self, idx, text): + if self.tokenizer.use_phonemes: + token_ids = self.get_phonemes(idx, text)["token_ids"] + else: + token_ids = self.tokenizer.text_to_ids(text) + return np.array(token_ids, dtype=np.int32) def load_data(self, idx): - item = self.items[idx] + item = self.samples[idx] - if len(item) == 5: - text, wav_file, speaker_name, language_name, attn_file = item - else: - text, wav_file, speaker_name, language_name = item - attn = None - raw_text = text + raw_text = item["text"] - wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) + wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) # apply noise for augmentation if self.use_noise_augment: - wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + wav = noise_augment_audio(wav) - if not self.input_seq_computed: - if self.use_phonemes: - text = self._load_or_generate_phoneme_sequence( - wav_file, - text, - self.phoneme_cache_path, - self.enable_eos_bos, - self.cleaners, - language_name if language_name else self.phoneme_language, - self.custom_symbols, - self.characters, - self.add_blank, - ) - else: - text = np.asarray( - text_to_sequence( - text, - [self.cleaners], - custom_symbols=self.custom_symbols, - tp=self.characters, - add_blank=self.add_blank, - ), - dtype=np.int32, - ) + # get token ids + token_ids = self.get_token_ids(idx, item["text"]) - assert text.size > 0, self.items[idx][1] - assert wav.size > 0, self.items[idx][1] + # get pre-computed attention maps + attn = None + if "alignment_file" in item: + attn = self.get_attn_mask(item["alignment_file"]) - if "attn_file" in locals(): - attn = np.load(attn_file) - - if len(text) > self.max_seq_len: - # return a different sample if the phonemized - # text is longer than the threshold - # TODO: find a better fix + # after phonemization the text length may change + # this is a shareful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len: + self.rescue_item_idx += 1 return self.load_data(self.rescue_item_idx) - pitch = None + # get f0 values + f0 = None if self.compute_f0: - pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) - pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32)) + f0 = self.get_f0(idx)["f0"] sample = { "raw_text": raw_text, - "text": text, + "token_ids": token_ids, "wav": wav, - "pitch": pitch, + "pitch": f0, "attn": attn, - "item_idx": self.items[idx][1], - "speaker_name": speaker_name, - "language_name": language_name, - "wav_file_name": os.path.basename(wav_file), + "item_idx": item["audio_file"], + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "wav_file_name": os.path.basename(item["audio_file"]), } return sample @staticmethod - def _phoneme_worker(args): - item = args[0] - func_args = args[1] - text, wav_file, *_ = item - func_args[3] = ( - item[3] if item[3] else func_args[3] - ) # override phoneme language if specified by the dataset formatter - phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args) - return phonemes + def _compute_lengths(samples): + new_samples = [] + for item in samples: + 
audio_length = os.path.getsize(item["audio_file"]) / 16 * 8 # assuming 16bit audio + text_length = len(item["text"]) + item["audio_length"] = audio_length + item["text_length"] = text_length + new_samples += [item] + return new_samples - def compute_input_seq(self, num_workers=0): - """Compute the input sequences with multi-processing. - Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" - if not self.use_phonemes: - if self.verbose: - print(" | > Computing input sequences ...") - for idx, item in enumerate(tqdm.tqdm(self.items)): - text, *_ = item - sequence = np.asarray( - text_to_sequence( - text, - [self.cleaners], - custom_symbols=self.custom_symbols, - tp=self.characters, - add_blank=self.add_blank, - ), - dtype=np.int32, - ) - self.items[idx][0] = sequence - - else: - func_args = [ - self.phoneme_cache_path, - self.enable_eos_bos, - self.cleaners, - self.phoneme_language, - self.custom_symbols, - self.characters, - self.add_blank, - ] - if self.verbose: - print(" | > Computing phonemes ...") - if num_workers == 0: - for idx, item in enumerate(tqdm.tqdm(self.items)): - phonemes = self._phoneme_worker([item, func_args]) - self.items[idx][0] = phonemes + @staticmethod + def filter_by_length(lengths: List[int], min_len: int, max_len: int): + idxs = np.argsort(lengths) # ascending order + ignore_idx = [] + keep_idx = [] + for idx in idxs: + length = lengths[idx] + if length < min_len or length > max_len: + ignore_idx.append(idx) else: - with Pool(num_workers) as p: - phonemes = list( - tqdm.tqdm( - p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]), - total=len(self.items), - ) - ) - for idx, p in enumerate(phonemes): - self.items[idx][0] = p + keep_idx.append(idx) + return ignore_idx, keep_idx - def sort_and_filter_items(self, by_audio_len=False): + @staticmethod + def sort_by_length(samples: List[List]): + audio_lengths = [s["audio_length"] for s in samples] + idxs = np.argsort(audio_lengths) # ascending order + return idxs + + @staticmethod + def create_buckets(samples, batch_group_size: int): + assert batch_group_size > 0 + for i in range(len(samples) // batch_group_size): + offset = i * batch_group_size + end_offset = offset + batch_group_size + temp_items = samples[offset:end_offset] + random.shuffle(temp_items) + samples[offset:end_offset] = temp_items + return samples + + @staticmethod + def _select_samples_by_idx(idxs, samples): + samples_new = [] + for idx in idxs: + samples_new.append(samples[idx]) + return samples_new + + def preprocess_samples(self): r"""Sort `items` based on text length or audio length in ascending order. Filter out samples outside the length range. - - Args: - by_audio_len (bool): if True, sort by audio length else by text length.
""" - # compute the target sequence length - if by_audio_len: - lengths = [] - for item in self.items: - lengths.append(os.path.getsize(item[1]) / 16 * 8) # assuming 16bit audio - lengths = np.array(lengths) - else: - lengths = np.array([len(ins[0]) for ins in self.items]) + samples = self._compute_lengths(self.samples) + + # sort items based on the sequence length in ascending order + text_lengths = [i["text_length"] for i in samples] + audio_lengths = [i["audio_length"] for i in samples] + text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) + audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len) + keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) + ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) + + samples = self._select_samples_by_idx(keep_idx, samples) + + sorted_idxs = self.sort_by_length(samples) + + if self.start_by_longest: + longest_idxs = sorted_idxs[-1] + sorted_idxs[-1] = sorted_idxs[0] + sorted_idxs[0] = longest_idxs + + samples = self._select_samples_by_idx(sorted_idxs, samples) + + if len(samples) == 0: + raise RuntimeError(" [!] No samples left") - idxs = np.argsort(lengths) - new_items = [] - ignored = [] - for i, idx in enumerate(idxs): - length = lengths[idx] - if length < self.min_seq_len or length > self.max_seq_len: - ignored.append(idx) - else: - new_items.append(self.items[idx]) # shuffle batch groups + # create batches with similar length items + # the larger the `batch_group_size`, the higher the length variety in a batch. if self.batch_group_size > 0: - for i in range(len(new_items) // self.batch_group_size): - offset = i * self.batch_group_size - end_offset = offset + self.batch_group_size - temp_items = new_items[offset:end_offset] - random.shuffle(temp_items) - new_items[offset:end_offset] = temp_items - self.items = new_items + samples = self.create_buckets(samples, self.batch_group_size) + + # update items to the new sorted items + audio_lengths = [s["audio_length"] for s in samples] + text_lengths = [s["text_length"] for s in samples] + self.samples = samples if self.verbose: - print(" | > Max length sequence: {}".format(np.max(lengths))) - print(" | > Min length sequence: {}".format(np.min(lengths))) - print(" | > Avg length sequence: {}".format(np.mean(lengths))) - print( - " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format( - self.max_seq_len, self.min_seq_len, len(ignored) - ) - ) + print(" | > Preprocessing samples") + print(" | > Max text length: {}".format(np.max(text_lengths))) + print(" | > Min text length: {}".format(np.min(text_lengths))) + print(" | > Avg text length: {}".format(np.mean(text_lengths))) + print(" | ") + print(" | > Max audio length: {}".format(np.max(audio_lengths))) + print(" | > Min audio length: {}".format(np.min(audio_lengths))) + print(" | > Avg audio length: {}".format(np.mean(audio_lengths))) + print(f" | > Num. instances discarded samples: {len(ignore_idx)}") print(" | > Batch group size: {}.".format(self.batch_group_size)) - def __len__(self): - return len(self.items) - - def __getitem__(self, idx): - return self.load_data(idx) - @staticmethod def _sort_batch(batch, text_lengths): """Sort the batch by the input text length for RNN efficiency. 
@@ -402,10 +382,10 @@ class TTSDataset(Dataset): # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.abc.Mapping): - text_lengths = np.array([len(d["text"]) for d in batch]) + token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) # sort items with text input length for RNN efficiency - batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths) + batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths) # convert list of dicts to dict of lists batch = {k: [dic[k] for dic in batch] for k in batch[0]} @@ -447,7 +427,7 @@ class TTSDataset(Dataset): stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch - text = prepare_data(batch["text"]).astype(np.int32) + token_ids = prepare_data(batch["token_ids"]).astype(np.int32) # PAD features with longest instance mel = prepare_tensor(mel, self.outputs_per_step) @@ -456,12 +436,13 @@ class TTSDataset(Dataset): mel = mel.transpose(0, 2, 1) # convert things to pytorch - text_lengths = torch.LongTensor(text_lengths) - text = torch.LongTensor(text) + token_ids_lengths = torch.LongTensor(token_ids_lengths) + token_ids = torch.LongTensor(token_ids) mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) + # speaker vectors if d_vectors is not None: d_vectors = torch.FloatTensor(d_vectors) @@ -472,14 +453,13 @@ class TTSDataset(Dataset): language_ids = torch.LongTensor(language_ids) # compute linear spectrogram + linear = None if self.compute_linear_spec: linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] linear = prepare_tensor(linear, self.outputs_per_step) linear = linear.transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] linear = torch.FloatTensor(linear).contiguous() - else: - linear = None # format waveforms wav_padded = None @@ -495,8 +475,7 @@ class TTSDataset(Dataset): wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) wav_padded.transpose_(1, 2) - # compute f0 - # TODO: compare perf in collate_fn vs in load_data + # format F0 if self.compute_f0: pitch = prepare_data(batch["pitch"]) assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" @@ -504,23 +483,22 @@ class TTSDataset(Dataset): else: pitch = None - # collate attention alignments + # format attention masks + attns = None if batch["attn"][0] is not None: attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): pad2 = mel.shape[1] - attn.shape[1] - pad1 = text.shape[1] - attn.shape[0] + pad1 = token_ids.shape[1] - attn.shape[0] assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" attn = np.pad(attn, [[0, pad1], [0, pad2]]) attns[idx] = attn attns = prepare_tensor(attns, self.outputs_per_step) attns = torch.FloatTensor(attns).unsqueeze(1) - else: - attns = None - # TODO: return dictionary + return { - "text": text, - "text_lengths": text_lengths, + "token_id": token_ids, + "token_id_lengths": token_ids_lengths, "speaker_names": batch["speaker_name"], "linear": linear, "mel": mel, @@ -546,22 +524,185 @@ class TTSDataset(Dataset): ) -class PitchExtractor: - """Pitch Extractor for computing F0 from wav files. 
+class PhonemeDataset(Dataset): + """Phoneme Dataset for converting input text to phonemes and then token IDs + + At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data + loading latency. If `cache_path` is already present, it skips the pre-computation. + Args: - items (List[List]): Dataset samples. - verbose (bool): Whether to print the progress. + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + tokenizer (TTSTokenizer): + Tokenizer to convert input text to phonemes. + + cache_path (str): + Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation. + + precompute_num_workers (int): + Number of workers used for pre-computing the phonemes. Defaults to 0. """ def __init__( self, - items: List[List], - verbose=False, + samples: Union[List[Dict], List[List]], + tokenizer: "TTSTokenizer", + cache_path: str, + precompute_num_workers=0, ): - self.items = items + self.samples = samples + self.tokenizer = tokenizer + self.cache_path = cache_path + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + + def __getitem__(self, index): + item = self.samples[index] + ids = self.compute_or_load(item["audio_file"], item["text"]) + ph_hat = self.tokenizer.ids_to_text(ids) + return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} + + def __len__(self): + return len(self.samples) + + def compute_or_load(self, wav_file, text): + """Compute phonemes for the given text. + + If the phonemes are already cached, load them from cache. + """ + file_name = os.path.splitext(os.path.basename(wav_file))[0] + file_ext = "_phoneme.npy" + cache_path = os.path.join(self.cache_path, file_name + file_ext) + try: + ids = np.load(cache_path) + except FileNotFoundError: + ids = self.tokenizer.text_to_ids(text) + np.save(cache_path, ids) + return ids + + def get_pad_id(self): + """Get pad token ID for sequence padding""" + return self.tokenizer.pad_id + + def precompute(self, num_workers=1): + """Precompute phonemes for all samples. + + We use pytorch dataloader because we are lazy. + """ + print("[*] Pre-computing phonemes...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + for _ in dataloder: + pbar.update(batch_size) + + def collate_fn(self, batch): + ids = [item["token_ids"] for item in batch] + ids_lens = [item["token_ids_len"] for item in batch] + texts = [item["text"] for item in batch] + texts_hat = [item["ph_hat"] for item in batch] + ids_lens_max = max(ids_lens) + ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id()) + for i, ids_len in enumerate(ids_lens): + ids_torch[i, :ids_len] = torch.LongTensor(ids[i]) + return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> PhonemeDataset ") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") + + +class F0Dataset: + """F0 Dataset for computing F0 from wav files in CPU + + Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. 
It + also computes the mean and std of F0 values if `normalize_f0` is True. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute F0 from wav files. + + cache_path (str): + Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the F0 values. Defaults to 0. + + normalize_f0 (bool): + Whether to normalize F0 values by mean and std. Defaults to True. + """ + + def __init__( + self, + samples: Union[List[List], List[Dict]], + ap: "AudioProcessor", + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + self.samples = samples + self.ap = ap self.verbose = verbose + self.cache_path = cache_path + self.normalize_f0 = normalize_f0 + self.pad_id = 0.0 self.mean = None self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_f0: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + f0 = self.compute_or_load(item["audio_file"]) + if self.normalize_f0: + assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" + f0 = self.normalize(f0) + return {"audio_file": item["audio_file"], "f0": f0} + + def __len__(self): + return len(self.samples) + + def precompute(self, num_workers=0): + print("[*] Pre-computing F0s...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_f0 = self.normalize_f0 + self.normalize_f0 = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + computed_data = [] + for batch in dataloder: + f0 = batch["f0"] + computed_data.append(f for f in f0) + pbar.update(batch_size) + self.normalize_f0 = normalize_f0 + + if self.normalize_f0: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + pitch_mean, pitch_std = self.compute_pitch_stats(computed_data) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id @staticmethod def create_pitch_file_path(wav_file, cache_path): @@ -583,70 +724,49 @@ class PitchExtractor: mean, std = np.mean(nonzeros), np.std(nonzeros) return mean, std - def normalize_pitch(self, pitch): + def load_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, pitch): zero_idxs = np.where(pitch == 0.0)[0] pitch = pitch - self.mean pitch = pitch / self.std pitch[zero_idxs] = 0.0 return pitch - def denormalize_pitch(self, pitch): + def denormalize(self, pitch): zero_idxs = np.where(pitch == 0.0)[0] pitch *= self.std pitch += self.mean pitch[zero_idxs] = 0.0 return pitch - @staticmethod - def load_or_compute_pitch(ap, wav_file, cache_path): + def compute_or_load(self, wav_file): """ compute pitch and return a numpy array of pitch values """ - pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + pitch_file = 
self.create_pitch_file_path(wav_file, self.cache_path) if not os.path.exists(pitch_file): - pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) else: pitch = np.load(pitch_file) return pitch.astype(np.float32) - @staticmethod - def _pitch_worker(args): - item = args[0] - ap = args[1] - cache_path = args[2] - _, wav_file, *_ = item - pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) - if not os.path.exists(pitch_file): - pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) - return pitch - return None + def collate_fn(self, batch): + audio_file = [item["audio_file"] for item in batch] + f0s = [item["f0"] for item in batch] + f0_lens = [len(item["f0"]) for item in batch] + f0_lens_max = max(f0_lens) + f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id()) + for i, f0_len in enumerate(f0_lens): + f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i]) + return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens} - def compute_pitch(self, ap, cache_path, num_workers=0): - """Compute the input sequences with multi-processing. - Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" - if not os.path.exists(cache_path): - os.makedirs(cache_path, exist_ok=True) - - if self.verbose: - print(" | > Computing pitch features ...") - if num_workers == 0: - pitch_vecs = [] - for _, item in enumerate(tqdm.tqdm(self.items)): - pitch_vecs += [self._pitch_worker([item, ap, cache_path])] - else: - with Pool(num_workers) as p: - pitch_vecs = list( - tqdm.tqdm( - p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]), - total=len(self.items), - ) - ) - pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) - pitch_stats = {"mean": pitch_mean, "std": pitch_std} - np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) - - def load_pitch_stats(self, cache_path): - stats_path = os.path.join(cache_path, "pitch_stats.npy") - stats = np.load(stats_path, allow_pickle=True).item() - self.mean = stats["mean"].astype(np.float32) - self.std = stats["std"].astype(np.float32) + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> F0Dataset ") + print(f"{indent}| > Number of instances : {len(self.samples)}") diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 1f23f85e..aacfc647 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("\t") wav_file = os.path.join(root_path, cols[0] + ".wav") text = cols[1] - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument wav_file = cols[1].strip() text = cols[0].strip() wav_file = os.path.join(root_path, "wavs", wav_file) - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume text = cols[1].strip() folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" wav_file = os.path.join(root_path, folder_name, wav_file) - items.append([text, 
wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None): wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") if os.path.isfile(wav_file): text = cols[1].strip() - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) else: # M-AI-Labs have some missing samples, so just print the warning print("> File %s does not exist!" % (wav_file)) @@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[2] - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -129,11 +129,15 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg txt_file = os.path.join(root_path, meta_file) items = [] with open(txt_file, "r", encoding="utf-8") as ttf: + speaker_id = 0 for idx, line in enumerate(ttf): + # 2 samples per speaker to avoid eval split issues + if idx % 2 == 0: + speaker_id += 1 cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[2] - items.append([text, wav_file, f"ljspeech-{idx}"]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"}) return items @@ -150,7 +154,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg if not os.path.exists(wav_file): print(f" [!] {wav_file} in metafile does not exist. Skipping...") continue - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -165,7 +169,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("|") wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav") text = cols[1] - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -179,7 +183,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("|") wav_file = os.path.join(root_path, cols[0]) text = cols[1] - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -193,7 +197,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument utt_id = line.split()[1] text = line[line.find('"') + 1 : line.rfind('"') - 1] wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items @@ -213,7 +217,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None): if speaker_name in ignored_speakers: continue wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) - items.append([text, wav_file, "MCV_" + speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name}) return items @@ -240,7 +244,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_name in ignored_speakers: continue - items.append([text, wav_file, "LTTS_" + 
speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"}) for item in items: assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}" return items @@ -259,7 +263,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar skipped_files.append(wav_file) continue text = cols[1].strip() - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) print(f" [!] {len(skipped_files)} files skipped. They don't exist...") return items @@ -281,12 +285,32 @@ def brspeech(root_path, meta_file, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_id in ignored_speakers: continue - items.append([text, wav_file, speaker_id]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id}) return items -def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): - """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" +def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None): + """VCTK dataset v0.92. + + URL: + https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip + + This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```. + It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset. + + mic1: + Audio recorded using an omni-directional microphone (DPA 4035). + Contains very low frequency noises. + This is the same audio released in previous versions of VCTK: + https://doi.org/10.7488/ds/1994 + + mic2: + Audio recorded using a small diaphragm condenser microphone with + very wide bandwidth (Sennheiser MKH 800). + Two speakers, p280 and p315 had technical issues of the audio + recordings using MKH 800. + """ + file_ext = "flac" items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for meta_file in meta_files: @@ -298,26 +322,33 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): continue with open(meta_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] - wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") - items.append([text, wav_file, "VCTK_" + speaker_id]) - + # p280 has no mic2 recordings + if speaker_id == "p280": + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}") + else: + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") + if os.path.exists(wav_file): + items.append([text, wav_file, "VCTK_" + speaker_id]) + else: + print(f" [!] 
wav files don't exist - {wav_file}") return items -def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): # pylint: disable=unused-argument +def vctk_old(root_path, meta_files=None, wavs_path="wav48"): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" + test_speakers = meta_files items = [] - txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) - for text_file in txt_files: - _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep) + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) file_id = txt_file.split(".")[0] - # ignore speakers - if isinstance(ignored_speakers, list): - if speaker_id in ignored_speakers: + if isinstance(test_speakers, list): # if is list ignore this speakers ids + if speaker_id in test_speakers: continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") - items.append([None, wav_file, "VCTK_" + speaker_id]) - + items.append([text, wav_file, "VCTK_old_" + speaker_id]) return items @@ -334,7 +365,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker in ignored_speakers: continue - items.append([text, wav_file, "MLS_" + speaker]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker}) return items @@ -404,7 +435,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin for line in ttf: wav_name, text = line.rstrip("\n").split("|") wav_path = os.path.join(root_path, "clips_22", wav_name) - items.append([text, wav_path, speaker_name]) + items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name}) return items @@ -418,5 +449,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[2].replace(" ", "") - items.append([text, wav_file, speaker_name]) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) return items diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py index 4766c77d..c0270e40 100644 --- a/TTS/tts/layers/generic/normalization.py +++ b/TTS/tts/layers/generic/normalization.py @@ -113,7 +113,7 @@ class ActNorm(nn.Module): denom = torch.sum(x_mask, [0, 2]) m = torch.sum(x * x_mask, [0, 2]) / denom m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom - v = m_sq - (m ** 2) + v = m_sq - (m**2) logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py index 0c87e9df..aeb45c7b 100644 --- a/TTS/tts/layers/generic/wavenet.py +++ b/TTS/tts/layers/generic/wavenet.py @@ -65,7 +65,7 @@ class WN(torch.nn.Module): self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") # intermediate layers for i in range(num_layers): - dilation = dilation_rate ** i + dilation = dilation_rate**i padding = int((kernel_size * dilation - dilation) / 2) in_layer = torch.nn.Conv1d( hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py index 
36ed668b..3b43e527 100644 --- a/TTS/tts/layers/glow_tts/encoder.py +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -101,7 +101,7 @@ class Encoder(nn.Module): self.encoder_type = encoder_type # embedding layer self.emb = nn.Embedding(num_chars, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) # init encoder module if encoder_type.lower() == "rel_pos_transformer": if use_prenet: diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index ba6aa1e2..0f837abf 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -88,7 +88,7 @@ class RelativePositionMultiHeadAttention(nn.Module): # relative positional encoding layers if rel_attn_window_size is not None: n_heads_rel = 1 if heads_share else num_heads - rel_stddev = self.k_channels ** -0.5 + rel_stddev = self.k_channels**-0.5 emb_rel_k = nn.Parameter( torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev ) @@ -235,7 +235,7 @@ class RelativePositionMultiHeadAttention(nn.Module): batch, heads, length, _ = x.size() # padd along column x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) - x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) # add 0's in the beginning that will skew the elements after reshape x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 7de45041..e03cf084 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -218,7 +218,7 @@ class GuidedAttentionLoss(torch.nn.Module): def _make_ga_mask(ilen, olen, sigma): grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen)) grid_x, grid_y = grid_x.float(), grid_y.float() - return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2))) + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))) @staticmethod def _make_masks(ilens, olens): @@ -553,7 +553,6 @@ class VitsGeneratorLoss(nn.Module): rl = rl.float().detach() gl = gl.float() loss += torch.mean(torch.abs(rl - gl)) - return loss * 2 @staticmethod @@ -588,13 +587,12 @@ class VitsGeneratorLoss(nn.Module): @staticmethod def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): - l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() - return l + return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() def forward( self, - waveform, - waveform_hat, + mel_slice, + mel_slice_hat, z_p, logs_q, m_p, @@ -610,8 +608,8 @@ class VitsGeneratorLoss(nn.Module): ): """ Shapes: - - waveform : :math:`[B, 1, T]` - - waveform_hat: :math:`[B, 1, T]` + - mel_slice : :math:`[B, 1, T]` + - mel_slice_hat: :math:`[B, 1, T]` - z_p: :math:`[B, C, T]` - logs_q: :math:`[B, C, T]` - m_p: :math:`[B, C, T]` @@ -624,23 +622,23 @@ class VitsGeneratorLoss(nn.Module): loss = 0.0 return_dict = {} z_mask = sequence_mask(z_len).float() - # compute mel spectrograms from the waveforms - mel = self.stft(waveform) - mel_hat = self.stft(waveform_hat) - # compute losses - loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha - loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha - loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha - 
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha + loss_kl = ( + self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1)) + * self.kl_loss_alpha + ) + loss_feat = ( + self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha + ) + loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha + loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration if use_speaker_encoder_as_loss: loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha - loss += loss_se + loss = loss + loss_se return_dict["loss_spk_encoder"] = loss_se - # pass losses to the dict return_dict["loss_gen"] = loss_gen return_dict["loss_kl"] = loss_kl @@ -665,20 +663,24 @@ class VitsDiscriminatorLoss(nn.Module): dr = dr.float() dg = dg.float() real_loss = torch.mean((1 - dr) ** 2) - fake_loss = torch.mean(dg ** 2) + fake_loss = torch.mean(dg**2) loss += real_loss + fake_loss real_losses.append(real_loss.item()) fake_losses.append(fake_loss.item()) - return loss, real_losses, fake_losses def forward(self, scores_disc_real, scores_disc_fake): loss = 0.0 return_dict = {} - loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake) + loss_disc, loss_disc_real, _ = self.discriminator_loss( + scores_real=scores_disc_real, scores_fake=scores_disc_fake + ) return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha loss = loss + return_dict["loss_disc"] return_dict["loss"] = loss + + for i, ldr in enumerate(loss_disc_real): + return_dict[f"loss_disc_real_{i}"] = ldr return return_dict @@ -740,6 +742,7 @@ class ForwardTTSLoss(nn.Module): alignment_logprob=None, alignment_hard=None, alignment_soft=None, + binary_loss_weight=None, ): loss = 0 return_dict = {} @@ -772,7 +775,12 @@ class ForwardTTSLoss(nn.Module): if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss - return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + if binary_loss_weight: + return_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) + else: + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss return_dict["loss"] = loss return return_dict diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py index 01a81e0b..7d751bc0 100644 --- a/TTS/tts/layers/tacotron/gst_layers.py +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -141,7 +141,7 @@ class MultiHeadAttention(nn.Module): # score = softmax(QK^T / (d_k ** 0.5)) scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k] - scores = scores / (self.key_dim ** 0.5) + scores = scores / (self.key_dim**0.5) scores = F.softmax(scores, dim=3) # out = score * V diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py index 9c33623e..c79b7099 100644 --- a/TTS/tts/layers/tacotron/tacotron2.py +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -6,7 +6,6 @@ from .attentions import init_attn from .common_layers import Linear, Prenet -# NOTE: linter has a problem with the current TF release # pylint: 
disable=no-value-for-parameter # pylint: disable=unexpected-keyword-arg class ConvBNBlock(nn.Module): diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index ef426ace..f97b584f 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -57,7 +57,7 @@ class TextEncoder(nn.Module): self.emb = nn.Embedding(n_vocab, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) if language_emb_dim: hidden_channels += language_emb_dim @@ -83,6 +83,7 @@ class TextEncoder(nn.Module): - x: :math:`[B, T]` - x_length: :math:`[B]` """ + assert x.shape[0] == x_lengths.shape[0] x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] # concat the lang emb in embedding chars @@ -90,7 +91,7 @@ class TextEncoder(nn.Module): x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t] x = self.encoder(x * x_mask, x_mask) stats = self.proj(x) * x_mask @@ -136,6 +137,9 @@ class ResidualCouplingBlock(nn.Module): def forward(self, x, x_mask, g=None, reverse=False): """ + Note: + Set `reverse` to True for inference. + Shapes: - x: :math:`[B, C, T]` - x_mask: :math:`[B, 1, T]` @@ -209,6 +213,9 @@ class ResidualCouplingBlocks(nn.Module): def forward(self, x, x_mask, g=None, reverse=False): """ + Note: + Set `reverse` to True for inference. + Shapes: - x: :math:`[B, C, T]` - x_mask: :math:`[B, 1, T]` diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py index 120d0944..738ee341 100644 --- a/TTS/tts/layers/vits/stochastic_duration_predictor.py +++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py @@ -33,7 +33,7 @@ class DilatedDepthSeparableConv(nn.Module): self.norms_1 = nn.ModuleList() self.norms_2 = nn.ModuleList() for i in range(num_layers): - dilation = kernel_size ** i + dilation = kernel_size**i padding = (kernel_size * dilation - dilation) // 2 self.convs_sep.append( nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) @@ -264,7 +264,7 @@ class StochasticDurationPredictor(nn.Module): # posterior encoder - neg log likelihood logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) nll_posterior_encoder = ( - torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q + torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q ) z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask @@ -279,7 +279,7 @@ class StochasticDurationPredictor(nn.Module): z = torch.flip(z, [1]) # flow layers - neg log likelihood - nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot + nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot return nll_flow_layers + nll_posterior_encoder flows = list(reversed(self.flows)) diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 4cc8b658..d76a3beb 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,52 +1,14 @@ -from TTS.tts.utils.text.symbols import make_symbols, parse_symbols +from typing import Dict, List, Union + from TTS.utils.generic_utils import find_module -def setup_model(config, 
speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None): +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": print(" > Using model: {}".format(config.model)) # fetch the right model implementation. if "base_model" in config and config["base_model"] is not None: MyModel = find_module("TTS.tts.models", config.base_model.lower()) else: MyModel = find_module("TTS.tts.models", config.model.lower()) - # define set of characters used by the model - if config.characters is not None: - # set characters from config - if hasattr(MyModel, "make_symbols"): - symbols = MyModel.make_symbols(config) - else: - symbols, phonemes = make_symbols(**config.characters) - else: - from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel - - if config.use_phonemes: - symbols = phonemes - # use default characters and assign them to config - config.characters = parse_symbols() - # consider special `blank` character if `add_blank` is set True - num_chars = len(symbols) + getattr(config, "add_blank", False) - config.num_chars = num_chars - # compatibility fix - if "model_params" in config: - config.model_params.num_chars = num_chars - if "model_args" in config: - config.model_args.num_chars = num_chars - if config.model.lower() in ["vits"]: # If model supports multiple languages - model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager) - else: - model = MyModel(config, speaker_manager=speaker_manager) + model = MyModel.init_from_config(config, samples) return model - - -# TODO; class registery -# def import_models(models_dir, namespace): -# for file in os.listdir(models_dir): -# path = os.path.join(models_dir, file) -# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): -# model_name = file[: file.find(".py")] if file.endswith(".py") else file -# importlib.import_module(namespace + "." 
+ model_name) -# -# -## automatically import any Python files in the models/ directory -# models_dir = os.path.dirname(__file__) -# import_models(models_dir, "TTS.tts.models") diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 2fc00b0b..c1e2ffb3 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Dict, List, Union import torch from coqpit import Coqpit @@ -12,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.io import load_fsspec @@ -100,11 +102,16 @@ class AlignTTS(BaseTTS): # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + def __init__( + self, + config: "AlignTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): - super().__init__(config) + super().__init__(config, ap, tokenizer, speaker_manager) self.speaker_manager = speaker_manager - self.config = config self.phase = -1 self.length_scale = ( float(config.model_args.length_scale) @@ -112,10 +119,6 @@ class AlignTTS(BaseTTS): else config.model_args.length_scale ) - if not self.config.model_args.num_chars: - _, self.config, num_chars = self.get_characters(config) - self.config.model_args.num_chars = num_chars - self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) self.embedded_speaker_dim = 0 @@ -382,19 +385,17 @@ class AlignTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( self, config, checkpoint_path, eval=False @@ -430,3 +431,19 @@ class AlignTTS(BaseTTS): def on_epoch_start(self, trainer): """Set AlignTTS training phase on epoch start.""" self.phase = self._set_phase(trainer.config, trainer.total_steps_done) + + @staticmethod + def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (AlignTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return AlignTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index ca8f3bb9..54939c61 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -9,6 +9,8 @@ from torch import nn from TTS.tts.layers.losses import TacotronLoss from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.generic_utils import format_aux_input from TTS.utils.io import load_fsspec from TTS.utils.training import gradual_training_scheduler @@ -17,8 +19,14 @@ from TTS.utils.training import gradual_training_scheduler class BaseTacotron(BaseTTS): """Base class shared by Tacotron and Tacotron2""" - def __init__(self, config: Coqpit): - super().__init__(config) + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) # pass all config fields as class attributes for key in config: @@ -107,6 +115,16 @@ class BaseTacotron(BaseTTS): """Get the model criterion used in training.""" return TacotronLoss(self.config) + @staticmethod + def init_from_config(config: Coqpit): + """Initialize model from config.""" + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config) + return BaseTacotron(config, ap, tokenizer, speaker_manager) + ############################# # COMMON COMPUTE FUNCTIONS ############################# diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index e52cd765..4e54b947 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -1,6 +1,6 @@ import os import random -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import torch import torch.distributed as dist @@ -9,33 +9,44 @@ from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from TTS.model import BaseModel -from TTS.tts.configs.shared_configs import CharactersConfig +from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text import make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file -class BaseTTS(BaseModel): +class BaseTTS(BaseTrainerModel): """Base `tts` class. Every new `tts` model must inherit this. It defines common `tts` specific functions on top of `Model` implementation. 
-
-    Notes on input/output tensor shapes:
-        Any input or output tensor of the model must be shaped as
-
-        - 3D tensors `batch x time x channels`
-        - 2D tensors `batch x channels`
-        - 1D tensors `batch x 1`
    """

+    def __init__(
+        self,
+        config: Coqpit,
+        ap: "AudioProcessor",
+        tokenizer: "TTSTokenizer",
+        speaker_manager: SpeakerManager = None,
+        language_manager: LanguageManager = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.ap = ap
+        self.tokenizer = tokenizer
+        self.speaker_manager = speaker_manager
+        self.language_manager = language_manager
+        self._set_model_args(config)
+
    def _set_model_args(self, config: Coqpit):
-        """Setup model args based on the config type.
+        """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
+
+        `ModelArgs` has all the fields required to initialize the model architecture.
+
+        `ModelConfig` has all the fields required for training, inference and contains `ModelArgs`.

        If the config is for training with a name like "*Config", then the model args are embedded in the
        config.model_args
@@ -44,8 +55,11 @@ class BaseTTS(BaseModel):
        """
        # don't use isinstance not to import recursively
        if "Config" in config.__class__.__name__:
+            config_num_chars = (
+                self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars
+            )
+            num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
            if "characters" in config:
-                _, self.config, num_chars = self.get_characters(config)
                self.config.num_chars = num_chars
                if hasattr(self.config, "model_args"):
                    config.model_args.num_chars = num_chars
@@ -58,22 +72,6 @@ class BaseTTS(BaseModel):
        else:
            raise ValueError("config must be either a *Config or *Args")

-    @staticmethod
-    def get_characters(config: Coqpit) -> str:
-        # TODO: implement CharacterProcessor
-        if config.characters is not None:
-            symbols, phonemes = make_symbols(**config.characters)
-        else:
-            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
-
-            config.characters = CharactersConfig(**parse_symbols())
-        model_characters = phonemes if config.use_phonemes else symbols
-        num_chars = len(model_characters) + getattr(config, "add_blank", False)
-        return model_characters, config, num_chars
-
-    def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
-        return get_speaker_manager(config, restore_path, data, out_path)
-
    def init_multispeaker(self, config: Coqpit, data: List = None):
        """Initialize a speaker embedding layer if needed and define expected embedding channel size for defining
        `in_channels` size of the connected layers.
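
The hunks above move the `AudioProcessor` and `TTSTokenizer` onto the model itself and resolve `num_chars` from the tokenizer. A minimal usage sketch of the new `init_from_config` flow, not part of the patch; the `AlignTTSConfig` import path and default config values are assumptions here and the concrete model class is only illustrative:

    from TTS.tts.configs.align_tts_config import AlignTTSConfig
    from TTS.tts.models.align_tts import AlignTTS

    config = AlignTTSConfig()

    # init_from_config builds the AudioProcessor, TTSTokenizer and SpeakerManager
    # from the config and hands them to the constructor, replacing the old
    # `assets["audio_processor"]` plumbing used by train_log/eval_log.
    model = AlignTTS.init_from_config(config)

    # The components now live on the model itself.
    print(model.ap.sample_rate, model.tokenizer.characters.num_chars)
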
@@ -170,8 +168,8 @@ class BaseTTS(BaseModel): Dict: [description] """ # setup input batch - text_input = batch["text"] - text_lengths = batch["text_lengths"] + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] speaker_names = batch["speaker_names"] linear_input = batch["linear"] mel_input = batch["mel"] @@ -239,7 +237,7 @@ class BaseTTS(BaseModel): config: Coqpit, assets: Dict, is_eval: bool, - data_items: List, + samples: Union[List[Dict], List[List]], verbose: bool, num_gpus: int, rank: int = None, @@ -247,8 +245,6 @@ class BaseTTS(BaseModel): if is_eval and not config.run_eval: loader = None else: - ap = assets["audio_processor"] - # setup multi-speaker attributes if hasattr(self, "speaker_manager") and self.speaker_manager is not None: if hasattr(config, "model_args"): @@ -264,12 +260,8 @@ class BaseTTS(BaseModel): speaker_id_mapping = None d_vector_mapping = None - # setup custom symbols if needed - custom_symbols = None - if hasattr(self, "make_symbols"): - custom_symbols = self.make_symbols(self.config) - - if hasattr(self, "language_manager"): + # setup multi-lingual attributes + if hasattr(self, "language_manager") and self.language_manager is not None: language_id_mapping = ( self.language_manager.language_id_mapping if self.args.use_language_embedding else None ) @@ -279,74 +271,40 @@ class BaseTTS(BaseModel): # init dataloader dataset = TTSDataset( outputs_per_step=config.r if "r" in config else 1, - text_cleaner=config.text_cleaner, compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, compute_f0=config.get("compute_f0", False), f0_cache_path=config.get("f0_cache_path", None), - meta_data=data_items, - ap=ap, - characters=config.characters, - custom_symbols=custom_symbols, - add_blank=config["add_blank"], + samples=samples, + ap=self.ap, return_wav=config.return_wav if "return_wav" in config else False, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, - min_seq_len=config.min_seq_len, - max_seq_len=config.max_seq_len, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, - use_phonemes=config.use_phonemes, - phoneme_language=config.phoneme_language, - enable_eos_bos=config.enable_eos_bos_chars, + precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, verbose=verbose, speaker_id_mapping=speaker_id_mapping, - d_vector_mapping=d_vector_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, language_id_mapping=language_id_mapping, ) - # pre-compute phonemes - if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]: - if hasattr(self, "eval_data_items") and is_eval: - dataset.items = self.eval_data_items - elif hasattr(self, "train_data_items") and not is_eval: - dataset.items = self.train_data_items - else: - # precompute phonemes for precise estimate of sequence lengths. 
- # otherwise `dataset.sort_items()` uses raw text lengths - dataset.compute_input_seq(config.num_loader_workers) - - # TODO: find a more efficient solution - # cheap hack - store items in the model state to avoid recomputing when reinit the dataset - if is_eval: - self.eval_data_items = dataset.items - else: - self.train_data_items = dataset.items - - # halt DDP processes for the main process to finish computing the phoneme cache + # wait all the DDP process to be ready if num_gpus > 1: dist.barrier() # sort input sequences from short to long - dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) - - # compute pitch frames and write to files. - if config.compute_f0 and rank in [None, 0]: - if not os.path.exists(config.f0_cache_path): - dataset.pitch_extractor.compute_pitch( - ap, config.get("f0_cache_path", None), config.num_loader_workers - ) - - # halt DDP processes for the main process to finish computing the F0 cache - if num_gpus > 1: - dist.barrier() - - # load pitch stats computed above by all the workers - if config.compute_f0: - dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + dataset.preprocess_samples() # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None # Weighted samplers + # TODO: make this DDP amenable assert not ( num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False) ), "language_weighted_sampler is not supported with DistributedSampler" @@ -357,17 +315,17 @@ class BaseTTS(BaseModel): if sampler is None: if getattr(config, "use_language_weighted_sampler", False): print(" > Using Language weighted sampler") - sampler = get_language_weighted_sampler(dataset.items) + sampler = get_language_weighted_sampler(dataset.samples) elif getattr(config, "use_speaker_weighted_sampler", False): print(" > Using Language weighted sampler") - sampler = get_speaker_weighted_sampler(dataset.items) + sampler = get_speaker_weighted_sampler(dataset.samples) loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, - shuffle=False, + shuffle=False, # shuffle is done in the dataset. collate_fn=dataset.collate_fn, - drop_last=False, + drop_last=False, # setting this False might cause issues in AMP training. sampler=sampler, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, @@ -403,7 +361,6 @@ class BaseTTS(BaseModel): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
""" - ap = assets["audio_processor"] print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} @@ -415,17 +372,15 @@ class BaseTTS(BaseModel): sen, self.config, "cuda" in str(next(self.parameters()).device), - ap, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, use_griffin_lim=True, do_trim_silence=False, ) test_audios["{}-audio".format(idx)] = outputs_dict["wav"] test_figures["{}-prediction".format(idx)] = plot_spectrogram( - outputs_dict["outputs"]["model_outputs"], ap, output_fig=False + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False ) test_figures["{}-alignment".format(idx)] = plot_alignment( outputs_dict["outputs"]["alignments"], output_fig=False diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index b2c41df5..a1273f7f 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, Tuple +from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit @@ -14,7 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager -from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram @dataclass @@ -170,17 +171,22 @@ class ForwardTTS(BaseTTS): """ # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + self._set_model_args(config) - super().__init__(config) - - self.speaker_manager = speaker_manager self.init_multispeaker(config) self.max_duration = self.args.max_duration self.use_aligner = self.args.use_aligner self.use_pitch = self.args.use_pitch - self.use_binary_alignment_loss = False + self.binary_loss_weight = 0.0 self.length_scale = ( float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale @@ -255,7 +261,7 @@ class ForwardTTS(BaseTTS): # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: print(" > Init speaker_embedding layer.") - self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels) + self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @staticmethod @@ -638,8 +644,9 @@ class ForwardTTS(BaseTTS): pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, input_lens=text_lengths, alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, - alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, - alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None, + alignment_soft=outputs["alignment_soft"], + alignment_hard=outputs["alignment_mas"], + binary_loss_weight=self.binary_loss_weight, ) # compute duration error durations_pred = outputs["durations"] @@ -666,17 +673,12 @@ class 
ForwardTTS(BaseTTS): # plot pitch figures if self.args.use_pitch: - pitch = batch["pitch"] - pitch_avg_expanded, _ = self.expand_encoder_outputs( - outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"] - ) - pitch = pitch[0, 0].data.cpu().numpy() - # TODO: denormalize before plotting - pitch = abs(pitch) - pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy() + pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) pitch_figures = { - "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False), - "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False), + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), } figures.update(pitch_figures) @@ -692,19 +694,17 @@ class ForwardTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( self, config, checkpoint_path, eval=False @@ -721,6 +721,21 @@ class ForwardTTS(BaseTTS): return ForwardTTSLoss(self.config) def on_train_step_start(self, trainer): - """Enable binary alignment loss when needed""" - if trainer.total_steps_done > self.config.binary_align_loss_start_step: - self.use_binary_alignment_loss = True + """Schedule binary loss weight.""" + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (ForwardTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return ForwardTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index c1e4c2ac..fea570a6 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -1,5 +1,5 @@ import math -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit @@ -14,6 +14,7 @@ from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.io import load_fsspec @@ -39,18 +40,31 @@ class GlowTTS(BaseTTS): Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments. Examples: + Init only model layers. + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig(num_chars=2) + >>> model = GlowTTS(config) + + Fully init a model ready for action. All the class attributes and class members + (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values. + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig >>> from TTS.tts.models.glow_tts import GlowTTS >>> config = GlowTTSConfig() - >>> model = GlowTTS(config) - + >>> model = GlowTTS.init_from_config(config, verbose=False) """ - def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None): + def __init__( + self, + config: GlowTTSConfig, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): - super().__init__(config) - - self.speaker_manager = speaker_manager + super().__init__(config, ap, tokenizer, speaker_manager) # pass all config fields to `self` # for fewer code change @@ -58,7 +72,6 @@ class GlowTTS(BaseTTS): for key in config: setattr(self, key, config[key]) - _, self.config, self.num_chars = self.get_characters(config) self.decoder_output_dim = config.out_channels # init multi-speaker layers if necessary @@ -94,25 +107,25 @@ class GlowTTS(BaseTTS): def init_multispeaker(self, config: Coqpit): """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding - vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension. + vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets + speaker embedding vector dimension to the d-vector dimension from the config. Args: config (Coqpit): Model configuration. """ self.embedded_speaker_dim = 0 - # init speaker manager - if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file): - raise ValueError( - " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model." 
- ) # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager if self.speaker_manager is not None: self.num_speakers = self.speaker_manager.num_speakers # set ultimate speaker embedding size - if config.use_speaker_embedding or config.use_d_vector_file: + if config.use_d_vector_file: self.embedded_speaker_dim = ( config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 ) + if self.speaker_manager is not None: + assert ( + config.d_vector_dim == self.speaker_manager.d_vector_dim + ), " [!] d-vector dimension mismatch b/w config and speaker manager." # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: print(" > Init speaker_embedding layer.") @@ -170,6 +183,8 @@ class GlowTTS(BaseTTS): if g is not None: if hasattr(self, "emb_g"): # use speaker embedding layer + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] else: # use d-vector @@ -180,12 +195,33 @@ class GlowTTS(BaseTTS): self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} ): # pylint: disable=dangerous-default-value """ - Shapes: - - x: :math:`[B, T]` - - x_lenghts::math:`B` - - y: :math:`[B, T, C]` - - y_lengths::math:`B` - - g: :math:`[B, C] or B` + Args: + x (torch.Tensor): + Input text sequence ids. :math:`[B, T_en]` + + x_lengths (torch.Tensor): + Lengths of input text sequences. :math:`[B]` + + y (torch.Tensor): + Target mel-spectrogram frames. :math:`[B, T_de, C_mel]` + + y_lengths (torch.Tensor): + Lengths of target mel-spectrogram frames. :math:`[B]` + + aux_input (Dict): + Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model. + :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model usind speaker-embedding + layer. 
:math:`B` + + Returns: + Dict: + - z: :math: `[B, T_de, C]` + - logdet: :math:`B` + - y_mean: :math:`[B, T_de, C]` + - y_log_scale: :math:`[B, T_de, C]` + - alignments: :math:`[B, T_en, T_de]` + - durations_log: :math:`[B, T_en, 1]` + - total_durations_log: :math:`[B, T_en, 1]` """ # [B, T, C] -> [B, C, T] y = y.transpose(1, 2) @@ -206,9 +242,9 @@ class GlowTTS(BaseTTS): with torch.no_grad(): o_scale = torch.exp(-2 * o_log_scale) logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] - logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) @@ -255,9 +291,9 @@ class GlowTTS(BaseTTS): # find the alignment path between z and encoder output o_scale = torch.exp(-2 * o_log_scale) logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] - logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() @@ -422,20 +458,18 @@ class GlowTTS(BaseTTS): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: @@ -446,7 +480,6 @@ class GlowTTS(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
""" - ap = assets["audio_processor"] print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} @@ -461,18 +494,16 @@ class GlowTTS(BaseTTS): sen, self.config, "cuda" in str(next(self.parameters()).device), - ap, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, use_griffin_lim=True, do_trim_silence=False, ) test_audios["{}-audio".format(idx)] = outputs["wav"] test_figures["{}-prediction".format(idx)] = plot_spectrogram( - outputs["outputs"]["model_outputs"], ap, output_fig=False + outputs["outputs"]["model_outputs"], self.ap, output_fig=False ) test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) return test_figures, test_audios @@ -499,7 +530,8 @@ class GlowTTS(BaseTTS): self.store_inverse() assert not self.training - def get_criterion(self): + @staticmethod + def get_criterion(): from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel return GlowTTSLoss() @@ -507,3 +539,20 @@ class GlowTTS(BaseTTS): def on_train_step_start(self, trainer): """Decide on every training step wheter enable/disable data depended initialization.""" self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps + + @staticmethod + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return GlowTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 4e46d252..8341f5bb 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -1,7 +1,8 @@ # coding: utf-8 +from typing import Dict, List, Union + import torch -from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast @@ -10,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG from TTS.tts.models.base_tacotron import BaseTacotron from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -24,12 +26,15 @@ class Tacotron(BaseTacotron): a multi-speaker model. Defaults to None. 
""" - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): - super().__init__(config) + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): - self.speaker_manager = speaker_manager - chars, self.config, _ = self.get_characters(config) - config.num_chars = self.num_chars = len(chars) + super().__init__(config, ap, tokenizer, speaker_manager) # pass all config fields to `self` # for fewer code change @@ -302,16 +307,30 @@ class Tacotron(BaseTacotron): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (TacotronConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Tacotron(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index ead3bf2b..d4e665e3 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -1,9 +1,8 @@ # coding: utf-8 -from typing import Dict +from typing import Dict, List, Union import torch -from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast @@ -12,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet from TTS.tts.models.base_tacotron import BaseTacotron from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -40,12 +40,16 @@ class Tacotron2(BaseTacotron): Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None. 
""" - def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): - super().__init__(config) + def __init__( + self, + config: "Tacotron2Config", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + + super().__init__(config, ap, tokenizer, speaker_manager) - self.speaker_manager = speaker_manager - chars, self.config, _ = self.get_characters(config) - config.num_chars = len(chars) self.decoder_output_dim = config.out_channels # pass all config fields to `self` @@ -325,16 +329,30 @@ class Tacotron2(BaseTacotron): self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use """Log training progress.""" - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) - logger.train_audios(steps, audios, ap.sample_rate) + logger.train_audios(steps, audios, self.ap.sample_rate) def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - figures, audios = self._create_logs(batch, outputs, ap) + figures, audios = self._create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (Tacotron2Config): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(new_config, samples) + return Tacotron2(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index b2e4be9e..a43e081c 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,30 +1,294 @@ import math -from dataclasses import dataclass, field +import os +from dataclasses import dataclass, field, replace from itertools import chain -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import torch - +import torch.distributed as dist import torchaudio from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn from torch import nn from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler +from TTS.tts.configs.shared_configs import CharactersConfig +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment -from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +############################## +# IO / Feature extraction +############################## + +# pylint: disable=global-statement +hann_window = {} +mel_basis = {} + + +def load_audio(file_path): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load(file_path) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + 
"_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = amp_to_db(spec) + return spec + + +############################## +# DATASET +############################## + + +class VitsDataset(TTSDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pad_id = self.tokenizer.characters.pad_id + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"]) + wav_filename = os.path.basename(item["audio_file"]) + + token_ids = self.get_token_ids(idx, item["text"]) + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + 
"language_name": item["language"], + } + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + return { + "tokens": token_padded, + "token_lens": token_lens, + "token_rel_lens": token_rel_lens, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + } + + +############################## +# MODEL DEFINITION +############################## + @dataclass class VitsArgs(Coqpit): @@ -36,7 +300,7 @@ class VitsArgs(Coqpit): Number of characters in the vocabulary. Defaults to 100. out_channels (int): - Number of output channels. Defaults to 513. + Number of output channels of the decoder. Defaults to 513. spec_segment_size (int): Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. @@ -171,6 +435,9 @@ class VitsArgs(Coqpit): speaker_encoder_model_path (str): Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "". + condition_dp_on_speaker (bool): + Condition the duration predictor on the speaker embedding. Defaults to True. + freeze_encoder (bool): Freeze the encoder weigths during training. Defaults to False. 
@@ -233,6 +500,7 @@ class VitsArgs(Coqpit): use_speaker_encoder_as_loss: bool = False speaker_encoder_config_path: str = "" speaker_encoder_model_path: str = "" + condition_dp_on_speaker: bool = True freeze_encoder: bool = False freeze_DP: bool = False freeze_PE: bool = False @@ -268,109 +536,88 @@ class Vits(BaseTTS): >>> model = Vits(config) """ - # pylint: disable=dangerous-default-value - def __init__( self, config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, language_manager: LanguageManager = None, ): - super().__init__(config) - - self.END2END = True - self.speaker_manager = speaker_manager - self.language_manager = language_manager - if config.__class__.__name__ == "VitsConfig": - # loading from VitsConfig - if "num_chars" not in config: - _, self.config, num_chars = self.get_characters(config) - config.model_args.num_chars = num_chars - else: - self.config = config - config.model_args.num_chars = config.num_chars - args = self.config.model_args - elif isinstance(config, VitsArgs): - # loading from VitsArgs - self.config = config - args = config - else: - raise ValueError("config must be either a VitsConfig or VitsArgs") - - self.args = args + super().__init__(config, ap, tokenizer, speaker_manager, language_manager) self.init_multispeaker(config) self.init_multilingual(config) - self.length_scale = args.length_scale - self.noise_scale = args.noise_scale - self.inference_noise_scale = args.inference_noise_scale - self.inference_noise_scale_dp = args.inference_noise_scale_dp - self.noise_scale_dp = args.noise_scale_dp - self.max_inference_len = args.max_inference_len - self.spec_segment_size = args.spec_segment_size + self.length_scale = self.args.length_scale + self.noise_scale = self.args.noise_scale + self.inference_noise_scale = self.args.inference_noise_scale + self.inference_noise_scale_dp = self.args.inference_noise_scale_dp + self.noise_scale_dp = self.args.noise_scale_dp + self.max_inference_len = self.args.max_inference_len + self.spec_segment_size = self.args.spec_segment_size self.text_encoder = TextEncoder( - args.num_chars, - args.hidden_channels, - args.hidden_channels, - args.hidden_channels_ffn_text_encoder, - args.num_heads_text_encoder, - args.num_layers_text_encoder, - args.kernel_size_text_encoder, - args.dropout_p_text_encoder, + self.args.num_chars, + self.args.hidden_channels, + self.args.hidden_channels, + self.args.hidden_channels_ffn_text_encoder, + self.args.num_heads_text_encoder, + self.args.num_layers_text_encoder, + self.args.kernel_size_text_encoder, + self.args.dropout_p_text_encoder, language_emb_dim=self.embedded_language_dim, ) self.posterior_encoder = PosteriorEncoder( - args.out_channels, - args.hidden_channels, - args.hidden_channels, - kernel_size=args.kernel_size_posterior_encoder, - dilation_rate=args.dilation_rate_posterior_encoder, - num_layers=args.num_layers_posterior_encoder, + self.args.out_channels, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_posterior_encoder, + dilation_rate=self.args.dilation_rate_posterior_encoder, + num_layers=self.args.num_layers_posterior_encoder, cond_channels=self.embedded_speaker_dim, ) self.flow = ResidualCouplingBlocks( - args.hidden_channels, - args.hidden_channels, - kernel_size=args.kernel_size_flow, - dilation_rate=args.dilation_rate_flow, - num_layers=args.num_layers_flow, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_flow, + 
dilation_rate=self.args.dilation_rate_flow, + num_layers=self.args.num_layers_flow, cond_channels=self.embedded_speaker_dim, ) - if args.use_sdp: + if self.args.use_sdp: self.duration_predictor = StochasticDurationPredictor( - args.hidden_channels, + self.args.hidden_channels, 192, 3, - args.dropout_p_duration_predictor, + self.args.dropout_p_duration_predictor, 4, - cond_channels=self.embedded_speaker_dim, + cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, language_emb_dim=self.embedded_language_dim, ) else: self.duration_predictor = DurationPredictor( - args.hidden_channels, + self.args.hidden_channels, 256, 3, - args.dropout_p_duration_predictor, + self.args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim, language_emb_dim=self.embedded_language_dim, ) self.waveform_decoder = HifiganGenerator( - args.hidden_channels, + self.args.hidden_channels, 1, - args.resblock_type_decoder, - args.resblock_dilation_sizes_decoder, - args.resblock_kernel_sizes_decoder, - args.upsample_kernel_sizes_decoder, - args.upsample_initial_channel_decoder, - args.upsample_rates_decoder, + self.args.resblock_type_decoder, + self.args.resblock_dilation_sizes_decoder, + self.args.resblock_kernel_sizes_decoder, + self.args.upsample_kernel_sizes_decoder, + self.args.upsample_initial_channel_decoder, + self.args.upsample_rates_decoder, inference_padding=0, cond_channels=self.embedded_speaker_dim, conv_pre_weight_norm=False, @@ -378,8 +625,8 @@ class Vits(BaseTTS): conv_post_bias=False, ) - if args.init_discriminator: - self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator) + if self.args.init_discriminator: + self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator) def init_multispeaker(self, config: Coqpit): """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer @@ -393,6 +640,7 @@ class Vits(BaseTTS): """ self.embedded_speaker_dim = 0 self.num_speakers = self.args.num_speakers + self.audio_transform = None if self.speaker_manager: self.num_speakers = self.speaker_manager.num_speakers @@ -406,7 +654,7 @@ class Vits(BaseTTS): # TODO: make this a function if self.args.use_speaker_encoder_as_loss: if self.speaker_manager.speaker_encoder is None and ( - not config.speaker_encoder_model_path or not config.speaker_encoder_config_path + not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path ): raise RuntimeError( " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" 
@@ -420,11 +668,14 @@ class Vits(BaseTTS): and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"] ): self.audio_transform = torchaudio.transforms.Resample( - orig_freq=self.audio_config["sample_rate"], - new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], - ) - else: - self.audio_transform = None + orig_freq=self.audio_config["sample_rate"], + new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], + ) + # pylint: disable=W0101,W0105 + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"], + ) def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init @@ -458,6 +709,35 @@ class Vits(BaseTTS): self.embedded_language_dim = 0 self.emb_l = None + def get_aux_input(self, aux_input: Dict): + sid, g, lid = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _freeze_layers(self): + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if hasattr(self, "emb_l"): + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + @staticmethod def _set_cond_input(aux_input: Dict): """Set the speaker conditioning input based on the multi-speaker mode.""" @@ -478,59 +758,55 @@ class Vits(BaseTTS): return sid, g, lid - def get_aux_input(self, aux_input: Dict): - sid, g, lid = self._set_cond_input(aux_input) - return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) - def get_aux_input_from_test_sentences(self, sentence_info): - if hasattr(self.config, "model_args"): - config = self.config.model_args + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] 
Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): + # find the alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + with torch.no_grad(): + o_scale = torch.exp(-2 * logs_p) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) + logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + + # duration predictor + attn_durations = attn.sum(3) + if self.args.use_sdp: + loss_duration = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + attn_durations, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = loss_duration / torch.sum(x_mask) else: - config = self.config + attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + log_durations = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) + outputs["loss_duration"] = loss_duration + return outputs, attn - # extract speaker and language info - text, speaker_name, style_wav, language_name = None, None, None, None - - if isinstance(sentence_info, list): - if len(sentence_info) == 1: - text = sentence_info[0] - elif len(sentence_info) == 2: - text, speaker_name = sentence_info - elif len(sentence_info) == 3: - text, speaker_name, style_wav = sentence_info - elif len(sentence_info) == 4: - text, speaker_name, style_wav, language_name = sentence_info - else: - text = sentence_info - - # get speaker id/d_vector - speaker_id, d_vector, language_id = None, None, None - if hasattr(self, "speaker_manager"): - if config.use_d_vector_file: - if speaker_name is None: - d_vector = self.speaker_manager.get_random_d_vector() - else: - d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False) - elif config.use_speaker_embedding: - if speaker_name is None: - speaker_id = self.speaker_manager.get_random_speaker_id() - else: - speaker_id = self.speaker_manager.speaker_ids[speaker_name] - - # get language id - if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: - language_id = self.language_manager.language_id_mapping[language_name] - - return { - "text": text, - "speaker_id": speaker_id, - "style_wav": style_wav, - "d_vector": d_vector, - "language_id": language_id, - "language_name": language_name, - } - - def forward( + def forward( # pylint: disable=dangerous-default-value self, x: torch.tensor, x_lengths: torch.tensor, @@ -558,10 +834,23 @@ class Vits(BaseTTS): - x_lengths: :math:`[B]` - y: :math:`[B, C, T_spec]` - y_lengths: :math:`[B]` - - waveform: :math:`[B, T_wav, 1]` + - waveform: :math:`[B, 1, T_wav]` - d_vectors: :math:`[B, C, 1]` - speaker_ids: :math:`[B]` - language_ids: 
:math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + - m_q: :math:`[B, C, T_dec]` + - logs_q: :math:`[B, C, T_dec]` + - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` + - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` """ outputs = {} sid, g, lid = self._set_cond_input(aux_input) @@ -582,51 +871,22 @@ class Vits(BaseTTS): # flow layers z_p = self.flow(z, y_mask, g=g) - # find the alignment path - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - with torch.no_grad(): - o_scale = torch.exp(-2 * logs_p) - logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)]) - logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) - logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp = logp2 + logp3 + logp1 + logp4 - attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() - # duration predictor - attn_durations = attn.sum(3) - if self.args.use_sdp: - loss_duration = self.duration_predictor( - x.detach() if self.args.detach_dp_input else x, - x_mask, - attn_durations, - g=g.detach() if self.args.detach_dp_input and g is not None else g, - lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, - ) - loss_duration = loss_duration / torch.sum(x_mask) - else: - attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask - log_durations = self.duration_predictor( - x.detach() if self.args.detach_dp_input else x, - x_mask, - g=g.detach() if self.args.detach_dp_input and g is not None else g, - lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, - ) - loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) - outputs["loss_duration"] = loss_duration + outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) # expand prior m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) # select a random feature segment for the waveform decoder - z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size) + z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) o = self.waveform_decoder(z_slice, g=g) wav_seg = segment( waveform, slice_ids * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length, + pad_short=True, ) if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: @@ -649,28 +909,49 @@ class Vits(BaseTTS): { "model_outputs": o, "alignments": attn.squeeze(1), - "z": z, - "z_p": z_p, "m_p": m_p, "logs_p": logs_p, + "z": z, + "z_p": z_p, "m_q": m_q, "logs_q": logs_q, "waveform_seg": wav_seg, "gt_spk_emb": gt_spk_emb, "syn_spk_emb": syn_spk_emb, + "slice_ids": slice_ids, } ) return outputs - def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}): + @staticmethod + def _set_x_lengths(x, aux_input): + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + def inference( + self, x, 
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} + ): # pylint: disable=dangerous-default-value """ + Note: + To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. + Shapes: - x: :math:`[B, T_seq]` - - d_vectors: :math:`[B, C, 1]` + - x_lengths: :math:`[B]` + - d_vectors: :math:`[B, C]` - speaker_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` """ sid, g, lid = self._set_cond_input(aux_input) - x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_lengths = self._set_x_lengths(x, aux_input) # speaker embedding if self.args.use_speaker_embedding and sid is not None: @@ -685,16 +966,24 @@ class Vits(BaseTTS): if self.args.use_sdp: logw = self.duration_predictor( - x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb + x, + x_mask, + g=g if self.args.condition_dp_on_speaker else None, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb, ) else: - logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb) + logw = self.duration_predictor( + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + ) w = torch.exp(logw) * x_mask * self.length_scale w_ceil = torch.ceil(w) y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] + + attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) @@ -747,39 +1036,17 @@ class Vits(BaseTTS): Returns: Tuple[Dict, Dict]: Model ouputs and computed losses. """ - # pylint: disable=attribute-defined-outside-init - if optimizer_idx not in [0, 1]: - raise ValueError(" [!] 
Unexpected `optimizer_idx`.") - if self.args.freeze_encoder: - for param in self.text_encoder.parameters(): - param.requires_grad = False + self._freeze_layers() - if hasattr(self, "emb_l"): - for param in self.emb_l.parameters(): - param.requires_grad = False - - if self.args.freeze_PE: - for param in self.posterior_encoder.parameters(): - param.requires_grad = False - - if self.args.freeze_DP: - for param in self.duration_predictor.parameters(): - param.requires_grad = False - - if self.args.freeze_flow_decoder: - for param in self.flow.parameters(): - param.requires_grad = False - - if self.args.freeze_waveform_decoder: - for param in self.waveform_decoder.parameters(): - param.requires_grad = False + mel_lens = batch["mel_lens"] if optimizer_idx == 0: - text_input = batch["text_input"] - text_lengths = batch["text_lengths"] - mel_lengths = batch["mel_lengths"] - linear_input = batch["linear_input"] + tokens = batch["tokens"] + token_lenghts = batch["token_lens"] + spec = batch["spec"] + spec_lens = batch["spec_lens"] + d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] language_ids = batch["language_ids"] @@ -787,69 +1054,86 @@ class Vits(BaseTTS): # generator pass outputs = self.forward( - text_input, - text_lengths, - linear_input.transpose(1, 2), - mel_lengths, - waveform.transpose(1, 2), + tokens, + token_lenghts, + spec, + spec_lens, + waveform, aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, ) - # cache tensors for the discriminator - self.y_disc_cache = None - self.wav_seg_disc_cache = None - self.y_disc_cache = outputs["model_outputs"] - self.wav_seg_disc_cache = outputs["waveform_seg"] - - # compute discriminator scores and features - outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc( - outputs["model_outputs"], outputs["waveform_seg"] - ) - - # compute losses - with autocast(enabled=False): # use float32 for the criterion - loss_dict = criterion[optimizer_idx]( - waveform_hat=outputs["model_outputs"].float(), - waveform=outputs["waveform_seg"].float(), - z_p=outputs["z_p"].float(), - logs_q=outputs["logs_q"].float(), - m_p=outputs["m_p"].float(), - logs_p=outputs["logs_p"].float(), - z_len=mel_lengths, - scores_disc_fake=outputs["scores_disc_fake"], - feats_disc_fake=outputs["feats_disc_fake"], - feats_disc_real=outputs["feats_disc_real"], - loss_duration=outputs["loss_duration"], - use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, - gt_spk_emb=outputs["gt_spk_emb"], - syn_spk_emb=outputs["syn_spk_emb"], - ) - - elif optimizer_idx == 1: - # discriminator pass - outputs = {} + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init # compute scores and features - outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc( - self.y_disc_cache.detach(), self.wav_seg_disc_cache + scores_disc_fake, _, scores_disc_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] ) # compute loss with autocast(enabled=False): # use float32 for the criterion loss_dict = criterion[optimizer_idx]( - outputs["scores_disc_real"], - outputs["scores_disc_fake"], + scores_disc_real, + scores_disc_fake, ) - return outputs, loss_dict + return outputs, loss_dict + + if optimizer_idx == 1: + mel = batch["mel"] + + # compute melspec segment + with autocast(enabled=False): + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, 
pad_short=True + ) + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.config.audio.fft_size, + sample_rate=self.config.audio.sample_rate, + num_mels=self.config.audio.num_mels, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + fmin=self.config.audio.mel_fmin, + fmax=self.config.audio.mel_fmax, + center=False, + ) + + # compute discriminator scores and features + scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_slice_hat=mel_slice.float(), + mel_slice=mel_slice_hat.float(), + z_p=self.model_outputs_cache["z_p"].float(), + logs_q=self.model_outputs_cache["logs_q"].float(), + m_p=self.model_outputs_cache["m_p"].float(), + logs_p=self.model_outputs_cache["logs_p"].float(), + z_len=mel_lens, + scores_disc_fake=scores_disc_fake, + feats_disc_fake=feats_disc_fake, + feats_disc_real=feats_disc_real, + loss_duration=self.model_outputs_cache["loss_duration"], + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], + syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], + ) + + return self.model_outputs_cache, loss_dict + + raise ValueError(" [!] Unexpected `optimizer_idx`.") def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use - y_hat = outputs[0]["model_outputs"] - y = outputs[0]["waveform_seg"] + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] figures = plot_results(y_hat, y, ap, name_prefix) sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() audios = {f"{name_prefix}/audio": sample_voice} - alignments = outputs[0]["alignments"] + alignments = outputs[1]["alignments"] align_img = alignments[0].data.cpu().numpy().T figures.update( @@ -857,7 +1141,6 @@ class Vits(BaseTTS): "alignment": plot_alignment(align_img, output_fig=False), } ) - return figures, audios def train_log( @@ -876,19 +1159,69 @@ class Vits(BaseTTS): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. 
""" - ap = assets["audio_processor"] - self._log(ap, batch, outputs, "train") + figures, audios = self._log(self.ap, batch, outputs, "train") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - ap = assets["audio_processor"] - return self._log(ap, batch, outputs, "eval") + figures, audios = self._log(self.ap, batch, outputs, "eval") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_d_vector() + else: + d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_speaker_id() + else: + speaker_id = self.speaker_manager.speaker_ids[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.language_id_mapping[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } @torch.no_grad() - def test_run(self, ap) -> Tuple[Dict, Dict]: + def test_run(self, assets) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. @@ -901,55 +1234,170 @@ class Vits(BaseTTS): test_figures = {} test_sentences = self.config.test_sentences for idx, s_info in enumerate(test_sentences): - try: - aux_inputs = self.get_aux_input_from_test_sentences(s_info) - wav, alignment, _, _ = synthesis( - self, - aux_inputs["text"], - self.config, - "cuda" in str(next(self.parameters()).device), - ap, - speaker_id=aux_inputs["speaker_id"], - d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], - language_id=aux_inputs["language_id"], - language_name=aux_inputs["language_name"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, - use_griffin_lim=True, - do_trim_silence=False, - ).values() - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) - except: # pylint: disable=bare-except - print(" !! 
Error creating Test Sentence -", idx) - return test_figures, test_audios + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.d_vectors + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if ( + self.language_manager is not None + and self.language_manager.language_id_mapping + and self.args.use_language_embedding + ): + language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + # compute spectrograms + batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) + batch["mel"] = spec_to_mel( + spec=batch["spec"], + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + ) + assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + + # compute spectrogram frame lengths + batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() + batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() + assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 + + # zero the padding frames + batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1) + batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1) + return batch + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init 
dataloader + dataset = VitsDataset( + samples=samples, + # batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # sampler for DDP + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + + # Weighted samplers + # TODO: make this DDP amenable + assert not ( + num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False) + ), "language_weighted_sampler is not supported with DistributedSampler" + assert not ( + num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False) + ), "speaker_weighted_sampler is not supported with DistributedSampler" + + if sampler is None: + if getattr(config, "use_language_weighted_sampler", False): + print(" > Using Language weighted sampler") + sampler = get_language_weighted_sampler(dataset.samples) + elif getattr(config, "use_speaker_weighted_sampler", False): + print(" > Using Language weighted sampler") + sampler = get_speaker_weighted_sampler(dataset.samples) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + drop_last=False, # setting this False might cause issues in AMP training. + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader def get_optimizer(self) -> List: """Initiate and return the GAN optimizers based on the config parameters. - It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. - Returns: List: optimizers. """ - gen_parameters = chain( - self.text_encoder.parameters(), - self.posterior_encoder.parameters(), - self.flow.parameters(), - self.duration_predictor.parameters(), - self.waveform_decoder.parameters(), - ) - # add the speaker embedding layer - if hasattr(self, "emb_g") and self.args.use_speaker_embedding and not self.args.use_d_vector_file: - gen_parameters = chain(gen_parameters, self.emb_g.parameters()) - # add the language embedding layer - if hasattr(self, "emb_l") and self.args.use_language_embedding: - gen_parameters = chain(gen_parameters, self.emb_l.parameters()) + # select generator parameters + optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) - optimizer0 = get_optimizer( + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer1 = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters ) - optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) return [optimizer0, optimizer1] def get_lr(self) -> List: @@ -958,7 +1406,7 @@ class Vits(BaseTTS): Returns: List: learning rates for each optimizer. 
""" - return [self.config.lr_gen, self.config.lr_disc] + return [self.config.lr_disc, self.config.lr_gen] def get_scheduler(self, optimizer) -> List: """Set the schedulers for each optimizer. @@ -969,9 +1417,9 @@ class Vits(BaseTTS): Returns: List: Schedulers, one for each optimizer. """ - scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) - scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) - return [scheduler0, scheduler1] + scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler_D, scheduler_G] def get_criterion(self): """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in @@ -981,40 +1429,14 @@ class Vits(BaseTTS): VitsGeneratorLoss, ) - return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)] - - @staticmethod - def make_symbols(config): - """Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the - whole training and inference steps.""" - _pad = config.characters["pad"] - _punctuations = config.characters["punctuations"] - _letters = config.characters["characters"] - _letters_ipa = config.characters["phonemes"] - symbols = [_pad] + list(_punctuations) + list(_letters) - if config.use_phonemes: - symbols += list(_letters_ipa) - return symbols - - @staticmethod - def get_characters(config: Coqpit): - if config.characters is not None: - symbols = Vits.make_symbols(config) - else: - from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel - parse_symbols, - phonemes, - symbols, - ) - - config.characters = parse_symbols() - if config.use_phonemes: - symbols = phonemes - num_chars = len(symbols) + getattr(config, "add_blank", False) - return symbols, config, num_chars + return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] def load_checkpoint( - self, config, checkpoint_path, eval=False + self, + config, + checkpoint_path, + eval=False, + strict=True, ): # pylint: disable=unused-argument, redefined-builtin """Load the model checkpoint and setup for training or inference""" state = torch.load(checkpoint_path, map_location=torch.device("cpu")) @@ -1022,7 +1444,97 @@ class Vits(BaseTTS): # TODO: consider baking the speaker encoder into the model and call it from there. # as it is probably easier for model distribution. 
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} - self.load_state_dict(state["model"]) + # handle fine-tuning from a checkpoint with additional speakers + if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") + emb_g = state["model"]["emb_g.weight"] + new_row = torch.randn(num_new_speakers, emb_g.shape[1]) + emb_g = torch.cat([emb_g, new_row], axis=0) + state["model"]["emb_g.weight"] = emb_g + # load the model weights + self.load_state_dict(state["model"], strict=strict) + if eval: self.eval() assert not self.training + + @staticmethod + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() + assert ( + upsample_rate == config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + + ap = AudioProcessor.init_from_config(config, verbose=verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + language_manager = LanguageManager.init_from_config(config) + + if config.model_args.speaker_encoder_model_path: + speaker_manager.init_speaker_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) + return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + + +################################## +# VITS CHARACTERS +################################## + + +class VitsCharacters(BaseCharacters): + """Characters class for VITs model for compatibility with pre-trained models""" + + def __init__( + self, + graphemes: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + ipa_characters: str = _phonemes, + ) -> None: + if ipa_characters is not None: + graphemes += ipa_characters + super().__init__(graphemes, punctuations, pad, None, None, "", is_unique=False, is_sorted=True) + + def _create_vocab(self): + self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + # pylint: disable=unnecessary-comprehension + self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + + @staticmethod + def init_from_config(config: Coqpit): + if config.characters is not None: + _pad = config.characters["pad"] + _punctuations = config.characters["punctuations"] + _letters = config.characters["characters"] + _letters_ipa = config.characters["phonemes"] + return ( + VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), + config, + ) + characters = VitsCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=None, + bos=None, + blank=self._blank, + is_unique=False, + 
is_sorted=True, + ) diff --git a/TTS/tts/tf/README.md b/TTS/tts/tf/README.md deleted file mode 100644 index 0f9d58e9..00000000 --- a/TTS/tts/tf/README.md +++ /dev/null @@ -1,20 +0,0 @@ -## Utilities to Convert Models to Tensorflow2 -Here there are experimental utilities to convert trained Torch models to Tensorflow (2.2>=). - -Converting Torch models to TF enables all the TF toolkit to be used for better deployment and device specific optimizations. - -Note that we do not plan to share training scripts for Tensorflow in near future. But any contribution in that direction would be more than welcome. - -To see how you can use TF model at inference, check the notebook. - -This is an experimental release. If you encounter an error, please put an issue or in the best send a PR but you are mostly on your own. - - -### Converting a Model -- Run ```convert_tacotron2_torch_to_tf.py --torch_model_path /path/to/torch/model.pth.tar --config_path /path/to/model/config.json --output_path /path/to/output/tf/model``` with the right arguments. - -### Known issues ans limitations -- We use a custom model load/save mechanism which enables us to store model related information with models weights. (Similar to Torch). However, it is prone to random errors. -- Current TF model implementation is slightly slower than Torch model. Hopefully, it'll get better with improving TF support for eager mode and ```tf.function```. -- TF implementation of Tacotron2 only supports regular Tacotron2 as in the paper. -- You can only convert models trained after TF model implementation since model layers has been updated in Torch model. diff --git a/TTS/tts/tf/layers/tacotron/common_layers.py b/TTS/tts/tf/layers/tacotron/common_layers.py deleted file mode 100644 index a6b87981..00000000 --- a/TTS/tts/tf/layers/tacotron/common_layers.py +++ /dev/null @@ -1,301 +0,0 @@ -import tensorflow as tf -from tensorflow import keras -from tensorflow.python.ops import math_ops - -# from tensorflow_addons.seq2seq import BahdanauAttention - -# NOTE: linter has a problem with the current TF release -# pylint: disable=no-value-for-parameter -# pylint: disable=unexpected-keyword-arg - - -class Linear(keras.layers.Layer): - def __init__(self, units, use_bias, **kwargs): - super().__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer") - self.activation = keras.layers.ReLU() - - def call(self, x): - """ - shapes: - x: B x T x C - """ - return self.activation(self.linear_layer(x)) - - -class LinearBN(keras.layers.Layer): - def __init__(self, units, use_bias, **kwargs): - super().__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer") - self.batch_normalization = keras.layers.BatchNormalization( - axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization" - ) - self.activation = keras.layers.ReLU() - - def call(self, x, training=None): - """ - shapes: - x: B x T x C - """ - out = self.linear_layer(x) - out = self.batch_normalization(out, training=training) - return self.activation(out) - - -class Prenet(keras.layers.Layer): - def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs): - super().__init__(**kwargs) - self.prenet_type = prenet_type - self.prenet_dropout = prenet_dropout - self.linear_layers = [] - if prenet_type == "bn": - self.linear_layers += [ - LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units) - ] - elif prenet_type == "original": - self.linear_layers += [ - Linear(unit, 
use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units) - ] - else: - raise RuntimeError(" [!] Unknown prenet type.") - if prenet_dropout: - self.dropout = keras.layers.Dropout(rate=0.5) - - def call(self, x, training=None): - """ - shapes: - x: B x T x C - """ - for linear in self.linear_layers: - if self.prenet_dropout: - x = self.dropout(linear(x), training=training) - else: - x = linear(x) - return x - - -def _sigmoid_norm(score): - attn_weights = tf.nn.sigmoid(score) - attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True) - return attn_weights - - -class Attention(keras.layers.Layer): - """TODO: implement forward_attention - TODO: location sensitive attention - TODO: implement attention windowing""" - - def __init__( - self, - attn_dim, - use_loc_attn, - loc_attn_n_filters, - loc_attn_kernel_size, - use_windowing, - norm, - use_forward_attn, - use_trans_agent, - use_forward_attn_mask, - **kwargs, - ): - super().__init__(**kwargs) - self.use_loc_attn = use_loc_attn - self.loc_attn_n_filters = loc_attn_n_filters - self.loc_attn_kernel_size = loc_attn_kernel_size - self.use_windowing = use_windowing - self.norm = norm - self.use_forward_attn = use_forward_attn - self.use_trans_agent = use_trans_agent - self.use_forward_attn_mask = use_forward_attn_mask - self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer") - self.inputs_layer = tf.keras.layers.Dense( - attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer" - ) - self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer") - if use_loc_attn: - self.location_conv1d = keras.layers.Conv1D( - filters=loc_attn_n_filters, - kernel_size=loc_attn_kernel_size, - padding="same", - use_bias=False, - name="location_layer/location_conv1d", - ) - self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense") - if norm == "softmax": - self.norm_func = tf.nn.softmax - elif norm == "sigmoid": - self.norm_func = _sigmoid_norm - else: - raise ValueError("Unknown value for attention norm type") - - def init_states(self, batch_size, value_length): - states = [] - if self.use_loc_attn: - attention_cum = tf.zeros([batch_size, value_length]) - attention_old = tf.zeros([batch_size, value_length]) - states = [attention_cum, attention_old] - if self.use_forward_attn: - alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1) - states.append(alpha) - return tuple(states) - - def process_values(self, values): - """cache values for decoder iterations""" - # pylint: disable=attribute-defined-outside-init - self.processed_values = self.inputs_layer(values) - self.values = values - - def get_loc_attn(self, query, states): - """compute location attention, query layer and - unnorm. 
attention weights""" - attention_cum, attention_old = states[:2] - attn_cat = tf.stack([attention_old, attention_cum], axis=2) - - processed_query = self.query_layer(tf.expand_dims(query, 1)) - processed_attn = self.location_dense(self.location_conv1d(attn_cat)) - score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn)) - score = tf.squeeze(score, axis=2) - return score, processed_query - - def get_attn(self, query): - """compute query layer and unnormalized attention weights""" - processed_query = self.query_layer(tf.expand_dims(query, 1)) - score = self.v(tf.nn.tanh(self.processed_values + processed_query)) - score = tf.squeeze(score, axis=2) - return score, processed_query - - def apply_score_masking(self, score, mask): # pylint: disable=no-self-use - """ignore sequence paddings""" - padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) - # Bias so padding positions do not contribute to attention distribution. - score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32) - return score - - def apply_forward_attention(self, alignment, alpha): # pylint: disable=no-self-use - # forward attention - fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0) - # compute transition potentials - new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment - # renormalize attention weights - new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True) - return new_alpha - - def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None): - states = [] - if self.use_loc_attn: - states = [old_states[0] + scores_norm, attn_weights] - if self.use_forward_attn: - states.append(new_alpha) - return tuple(states) - - def call(self, query, states): - """ - shapes: - query: B x D - """ - if self.use_loc_attn: - score, _ = self.get_loc_attn(query, states) - else: - score, _ = self.get_attn(query) - - # TODO: masking - # if mask is not None: - # self.apply_score_masking(score, mask) - # attn_weights shape == (batch_size, max_length, 1) - - # normalize attention scores - scores_norm = self.norm_func(score) - attn_weights = scores_norm - - # apply forward attention - new_alpha = None - if self.use_forward_attn: - new_alpha = self.apply_forward_attention(attn_weights, states[-1]) - attn_weights = new_alpha - - # update states tuple - # states = (cum_attn_weights, attn_weights, new_alpha) - states = self.update_states(states, scores_norm, attn_weights, new_alpha) - - # context_vector shape after sum == (batch_size, hidden_size) - context_vector = tf.matmul( - tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False - ) - context_vector = tf.squeeze(context_vector, axis=1) - return context_vector, attn_weights, states - - -# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b): -# dtype = processed_query.dtype -# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1] -# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2]) - - -# class LocationSensitiveAttention(BahdanauAttention): -# def __init__(self, -# units, -# memory=None, -# memory_sequence_length=None, -# normalize=False, -# probability_fn="softmax", -# kernel_initializer="glorot_uniform", -# dtype=None, -# name="LocationSensitiveAttention", -# location_attention_filters=32, -# location_attention_kernel_size=31): - -# super( self).__init__(units=units, -# memory=memory, -# memory_sequence_length=memory_sequence_length, -# 
normalize=normalize, -# probability_fn='softmax', ## parent module default -# kernel_initializer=kernel_initializer, -# dtype=dtype, -# name=name) -# if probability_fn == 'sigmoid': -# self.probability_fn = lambda score, _: self._sigmoid_normalization(score) -# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False) -# self.location_dense = keras.layers.Dense(units, use_bias=False) -# # self.v = keras.layers.Dense(1, use_bias=True) - -# def _location_sensitive_score(self, processed_query, keys, processed_loc): -# processed_query = tf.expand_dims(processed_query, 1) -# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2]) - -# def _location_sensitive(self, alignment_cum, alignment_old): -# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2) -# return self.location_dense(self.location_conv(alignment_cat)) - -# def _sigmoid_normalization(self, score): -# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True) - -# # def _apply_masking(self, score, mask): -# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) -# # # Bias so padding positions do not contribute to attention distribution. -# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) -# # return score - -# def _calculate_attention(self, query, state): -# alignment_cum, alignment_old = state[:2] -# processed_query = self.query_layer( -# query) if self.query_layer else query -# processed_loc = self._location_sensitive(alignment_cum, alignment_old) -# score = self._location_sensitive_score( -# processed_query, -# self.keys, -# processed_loc) -# alignment = self.probability_fn(score, state) -# alignment_cum = alignment_cum + alignment -# state[0] = alignment_cum -# state[1] = alignment -# return alignment, state - -# def compute_context(self, alignments): -# expanded_alignments = tf.expand_dims(alignments, 1) -# context = tf.matmul(expanded_alignments, self.values) -# context = tf.squeeze(context, [1]) -# return context - -# # def call(self, query, state): -# # alignment, next_state = self._calculate_attention(query, state) -# # return alignment, next_state diff --git a/TTS/tts/tf/layers/tacotron/tacotron2.py b/TTS/tts/tf/layers/tacotron/tacotron2.py deleted file mode 100644 index 1fe679d2..00000000 --- a/TTS/tts/tf/layers/tacotron/tacotron2.py +++ /dev/null @@ -1,322 +0,0 @@ -import tensorflow as tf -from tensorflow import keras - -from TTS.tts.tf.layers.tacotron.common_layers import Attention, Prenet -from TTS.tts.tf.utils.tf_utils import shape_list - - -# NOTE: linter has a problem with the current TF release -# pylint: disable=no-value-for-parameter -# pylint: disable=unexpected-keyword-arg -class ConvBNBlock(keras.layers.Layer): - def __init__(self, filters, kernel_size, activation, **kwargs): - super().__init__(**kwargs) - self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d") - self.batch_normalization = keras.layers.BatchNormalization( - axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization" - ) - self.dropout = keras.layers.Dropout(rate=0.5, name="dropout") - self.activation = keras.layers.Activation(activation, name="activation") - - def call(self, x, training=None): - o = self.convolution1d(x) - o = self.batch_normalization(o, training=training) - o = self.activation(o) - o = self.dropout(o, training=training) - return o - - -class Postnet(keras.layers.Layer): - def 
__init__(self, output_filters, num_convs, **kwargs): - super().__init__(**kwargs) - self.convolutions = [] - self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0")) - for idx in range(1, num_convs - 1): - self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}")) - self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}")) - - def call(self, x, training=None): - o = x - for layer in self.convolutions: - o = layer(o, training=training) - return o - - -class Encoder(keras.layers.Layer): - def __init__(self, output_input_dim, **kwargs): - super().__init__(**kwargs) - self.convolutions = [] - for idx in range(3): - self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}")) - self.lstm = keras.layers.Bidirectional( - keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm" - ) - - def call(self, x, training=None): - o = x - for layer in self.convolutions: - o = layer(o, training=training) - o = self.lstm(o) - return o - - -class Decoder(keras.layers.Layer): - # pylint: disable=unused-argument - def __init__( - self, - frame_dim, - r, - attn_type, - use_attn_win, - attn_norm, - prenet_type, - prenet_dropout, - use_forward_attn, - use_trans_agent, - use_forward_attn_mask, - use_location_attn, - attn_K, - separate_stopnet, - speaker_emb_dim, - enable_tflite, - **kwargs, - ): - super().__init__(**kwargs) - self.frame_dim = frame_dim - self.r_init = tf.constant(r, dtype=tf.int32) - self.r = tf.constant(r, dtype=tf.int32) - self.output_dim = r * self.frame_dim - self.separate_stopnet = separate_stopnet - self.enable_tflite = enable_tflite - - # layer constants - self.max_decoder_steps = tf.constant(1000, dtype=tf.int32) - self.stop_thresh = tf.constant(0.5, dtype=tf.float32) - - # model dimensions - self.query_dim = 1024 - self.decoder_rnn_dim = 1024 - self.prenet_dim = 256 - self.attn_dim = 128 - self.p_attention_dropout = 0.1 - self.p_decoder_dropout = 0.1 - - self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet") - self.attention_rnn = keras.layers.LSTMCell( - self.query_dim, - use_bias=True, - name="attention_rnn", - ) - self.attention_rnn_dropout = keras.layers.Dropout(0.5) - - # TODO: implement other attn options - self.attention = Attention( - attn_dim=self.attn_dim, - use_loc_attn=True, - loc_attn_n_filters=32, - loc_attn_kernel_size=31, - use_windowing=False, - norm=attn_norm, - use_forward_attn=use_forward_attn, - use_trans_agent=use_trans_agent, - use_forward_attn_mask=use_forward_attn_mask, - name="attention", - ) - self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn") - self.decoder_rnn_dropout = keras.layers.Dropout(0.5) - self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer") - self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer") - - def set_max_decoder_steps(self, new_max_steps): - self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32) - - def set_r(self, new_r): - self.r = tf.constant(new_r, dtype=tf.int32) - self.output_dim = self.frame_dim * new_r - - def build_decoder_initial_states(self, batch_size, memory_dim, memory_length): - zero_frame = tf.zeros([batch_size, self.frame_dim]) - zero_context = tf.zeros([batch_size, memory_dim]) - attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) - decoder_rnn_state = 
self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) - attention_states = self.attention.init_states(batch_size, memory_length) - return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states - - def step(self, prenet_next, states, memory_seq_length=None, training=None): - _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states - attention_rnn_input = tf.concat([prenet_next, context_next], -1) - attention_rnn_output, attention_rnn_state = self.attention_rnn( - attention_rnn_input, attention_rnn_state, training=training - ) - attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training) - context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training) - decoder_rnn_input = tf.concat([attention_rnn_output, context], -1) - decoder_rnn_output, decoder_rnn_state = self.decoder_rnn( - decoder_rnn_input, decoder_rnn_state, training=training - ) - decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training) - linear_projection_input = tf.concat([decoder_rnn_output, context], -1) - output_frame = self.linear_projection(linear_projection_input, training=training) - stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1) - stopnet_output = self.stopnet(stopnet_input, training=training) - output_frame = output_frame[:, : self.r * self.frame_dim] - states = ( - output_frame[:, self.frame_dim * (self.r - 1) :], - context, - attention_rnn_state, - decoder_rnn_state, - attention_states, - ) - return output_frame, stopnet_output, states, attention - - def decode(self, memory, states, frames, memory_seq_length=None): - B, _, _ = shape_list(memory) - num_iter = shape_list(frames)[1] // self.r - # init states - frame_zero = tf.expand_dims(states[0], 1) - frames = tf.concat([frame_zero, frames], axis=1) - outputs = tf.TensorArray(dtype=tf.float32, size=num_iter) - attentions = tf.TensorArray(dtype=tf.float32, size=num_iter) - stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter) - # pre-computes - self.attention.process_values(memory) - prenet_output = self.prenet(frames, training=True) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions): - prenet_next = prenet_output[:, step] - output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length) - outputs = outputs.write(step, output) - attentions = attentions.write(step, attention) - stop_tokens = stop_tokens.write(step, stop_token) - return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions - - _, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop( - lambda *arg: True, - _body, - loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=num_iter, - ) - - outputs = outputs.stack() - attentions = attentions.stack() - stop_tokens = stop_tokens.stack() - outputs = tf.transpose(outputs, [1, 0, 2]) - attentions = tf.transpose(attentions, [1, 0, 2]) - stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) - stop_tokens = tf.squeeze(stop_tokens, axis=2) - outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - def decode_inference(self, memory, states): - B, _, _ = shape_list(memory) - # init states - outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - 
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) - - # pre-computes - self.attention.process_values(memory) - - # iter vars - stop_flag = tf.constant(False, dtype=tf.bool) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag): - frame_next = states[0] - prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, attention = self.step(prenet_next, states, None, training=False) - stop_token = tf.math.sigmoid(stop_token) - outputs = outputs.write(step, output) - attentions = attentions.write(step, attention) - stop_tokens = stop_tokens.write(step, stop_token) - stop_flag = tf.greater(stop_token, self.stop_thresh) - stop_flag = tf.reduce_all(stop_flag) - return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag - - cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - _, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop( - cond, - _body, - loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps, - ) - - outputs = outputs.stack() - attentions = attentions.stack() - stop_tokens = stop_tokens.stack() - - outputs = tf.transpose(outputs, [1, 0, 2]) - attentions = tf.transpose(attentions, [1, 0, 2]) - stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) - stop_tokens = tf.squeeze(stop_tokens, axis=2) - outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - def decode_inference_tflite(self, memory, states): - """Inference with TF-Lite compatibility. 
It assumes - batch_size is 1""" - # init states - # dynamic_shape is not supported in TFLite - outputs = tf.TensorArray( - dtype=tf.float32, - size=self.max_decoder_steps, - element_shape=tf.TensorShape([self.output_dim]), - clear_after_read=False, - dynamic_size=False, - ) - # stop_flags = tf.TensorArray(dtype=tf.bool, - # size=self.max_decoder_steps, - # element_shape=tf.TensorShape( - # []), - # clear_after_read=False, - # dynamic_size=False) - attentions = () - stop_tokens = () - - # pre-computes - self.attention.process_values(memory) - - # iter vars - stop_flag = tf.constant(False, dtype=tf.bool) - step_count = tf.constant(0, dtype=tf.int32) - - def _body(step, memory, states, outputs, stop_flag): - frame_next = states[0] - prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, _ = self.step(prenet_next, states, None, training=False) - stop_token = tf.math.sigmoid(stop_token) - stop_flag = tf.greater(stop_token, self.stop_thresh) - stop_flag = tf.reduce_all(stop_flag) - # stop_flags = stop_flags.write(step, tf.logical_not(stop_flag)) - - outputs = outputs.write(step, tf.reshape(output, [-1])) - return step + 1, memory, states, outputs, stop_flag - - cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - step_count, memory, states, outputs, stop_flag = tf.while_loop( - cond, - _body, - loop_vars=(step_count, memory, states, outputs, stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps, - ) - - outputs = outputs.stack() - outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter - outputs = tf.expand_dims(outputs, axis=[0]) - outputs = tf.transpose(outputs, [1, 0, 2]) - outputs = tf.reshape(outputs, [1, -1, self.frame_dim]) - return outputs, stop_tokens, attentions - - def call(self, memory, states, frames=None, memory_seq_length=None, training=False): - if training: - return self.decode(memory, states, frames, memory_seq_length) - if self.enable_tflite: - return self.decode_inference_tflite(memory, states) - return self.decode_inference(memory, states) diff --git a/TTS/tts/tf/models/tacotron2.py b/TTS/tts/tf/models/tacotron2.py deleted file mode 100644 index 7a1d695d..00000000 --- a/TTS/tts/tf/models/tacotron2.py +++ /dev/null @@ -1,116 +0,0 @@ -import tensorflow as tf -from tensorflow import keras - -from TTS.tts.tf.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet -from TTS.tts.tf.utils.tf_utils import shape_list - - -# pylint: disable=too-many-ancestors, abstract-method -class Tacotron2(keras.models.Model): - def __init__( - self, - num_chars, - num_speakers, - r, - out_channels=80, - decoder_output_dim=80, - attn_type="original", - attn_win=False, - attn_norm="softmax", - attn_K=4, - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - separate_stopnet=True, - bidirectional_decoder=False, - enable_tflite=False, - ): - super().__init__() - self.r = r - self.decoder_output_dim = decoder_output_dim - self.out_channels = out_channels - self.bidirectional_decoder = bidirectional_decoder - self.num_speakers = num_speakers - self.speaker_embed_dim = 256 - self.enable_tflite = enable_tflite - - self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding") - self.encoder = Encoder(512, name="encoder") - # TODO: most of the decoder args have no use at the momment - self.decoder = Decoder( - decoder_output_dim, - r, - 
attn_type=attn_type, - use_attn_win=attn_win, - attn_norm=attn_norm, - prenet_type=prenet_type, - prenet_dropout=prenet_dropout, - use_forward_attn=forward_attn, - use_trans_agent=trans_agent, - use_forward_attn_mask=forward_attn_mask, - use_location_attn=location_attn, - attn_K=attn_K, - separate_stopnet=separate_stopnet, - speaker_emb_dim=self.speaker_embed_dim, - name="decoder", - enable_tflite=enable_tflite, - ) - self.postnet = Postnet(out_channels, 5, name="postnet") - - @tf.function(experimental_relax_shapes=True) - def call(self, characters, text_lengths=None, frames=None, training=None): - if training: - return self.training(characters, text_lengths, frames) - if not training: - return self.inference(characters) - raise RuntimeError(" [!] Set model training mode True or False") - - def training(self, characters, text_lengths, frames): - B, T = shape_list(characters) - embedding_vectors = self.embedding(characters, training=True) - encoder_output = self.encoder(embedding_vectors, training=True) - decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) - decoder_frames, stop_tokens, attentions = self.decoder( - encoder_output, decoder_states, frames, text_lengths, training=True - ) - postnet_frames = self.postnet(decoder_frames, training=True) - output_frames = decoder_frames + postnet_frames - return decoder_frames, output_frames, attentions, stop_tokens - - def inference(self, characters): - B, T = shape_list(characters) - embedding_vectors = self.embedding(characters, training=False) - encoder_output = self.encoder(embedding_vectors, training=False) - decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) - decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False) - postnet_frames = self.postnet(decoder_frames, training=False) - output_frames = decoder_frames + postnet_frames - print(output_frames.shape) - return decoder_frames, output_frames, attentions, stop_tokens - - @tf.function( - experimental_relax_shapes=True, - input_signature=[ - tf.TensorSpec([1, None], dtype=tf.int32), - ], - ) - def inference_tflite(self, characters): - B, T = shape_list(characters) - embedding_vectors = self.embedding(characters, training=False) - encoder_output = self.encoder(embedding_vectors, training=False) - decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) - decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False) - postnet_frames = self.postnet(decoder_frames, training=False) - output_frames = decoder_frames + postnet_frames - print(output_frames.shape) - return decoder_frames, output_frames, attentions, stop_tokens - - def build_inference( - self, - ): - # TODO: issue https://github.com/PyCQA/pylint/issues/3613 - input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) # pylint: disable=unexpected-keyword-arg - self(input_ids) diff --git a/TTS/tts/tf/utils/convert_torch_to_tf_utils.py b/TTS/tts/tf/utils/convert_torch_to_tf_utils.py deleted file mode 100644 index 2c615a7d..00000000 --- a/TTS/tts/tf/utils/convert_torch_to_tf_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -import numpy as np -import tensorflow as tf - -# NOTE: linter has a problem with the current TF release -# pylint: disable=no-value-for-parameter -# pylint: disable=unexpected-keyword-arg - - -def tf_create_dummy_inputs(): - """Create dummy inputs for TF Tacotron2 model""" - batch_size = 4 - max_input_length = 32 - max_mel_length = 128 - pad = 1 - n_chars = 24 - input_ids = 
tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32) - input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size]) - input_lengths[-1] = max_input_length - input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32) - mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80]) - mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size]) - mel_lengths[-1] = max_mel_length - mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32) - return input_ids, input_lengths, mel_outputs, mel_lengths - - -def compare_torch_tf(torch_tensor, tf_tensor): - """Compute the average absolute difference b/w torch and tf tensors""" - return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean() - - -def convert_tf_name(tf_name): - """Convert certain patterns in TF layer names to Torch patterns""" - tf_name_tmp = tf_name - tf_name_tmp = tf_name_tmp.replace(":0", "") - tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0") - tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1") - tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh") - tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight") - tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight") - tf_name_tmp = tf_name_tmp.replace("/beta", "/bias") - tf_name_tmp = tf_name_tmp.replace("/", ".") - return tf_name_tmp - - -def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): - """Transfer weigths from torch state_dict to TF variables""" - print(" > Passing weights from Torch to TF ...") - for tf_var in tf_vars: - torch_var_name = var_map_dict[tf_var.name] - print(f" | > {tf_var.name} <-- {torch_var_name}") - # if tuple, it is a bias variable - if not isinstance(torch_var_name, tuple): - torch_layer_name = ".".join(torch_var_name.split(".")[-2:]) - torch_weight = state_dict[torch_var_name] - if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name: - # out_dim, in_dim, filter -> filter, in_dim, out_dim - numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy() - elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name: - numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() - # if variable is for bidirectional lstm and it is a bias vector there - # needs to be pre-defined two matching torch bias vectors - elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name: - bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name] - assert len(bias_vectors) == 2 - numpy_weight = bias_vectors[0] + bias_vectors[1] - elif "rnn" in tf_var.name and "kernel" in tf_var.name: - numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() - elif "rnn" in tf_var.name and "bias" in tf_var.name: - bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key] - assert len(bias_vectors) == 2 - numpy_weight = bias_vectors[0] + bias_vectors[1] - elif "linear_layer" in torch_layer_name and "weight" in torch_var_name: - numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() - else: - numpy_weight = torch_weight.detach().cpu().numpy() - assert np.all( - tf_var.shape == numpy_weight.shape - ), f" [!] 
weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" - tf.keras.backend.set_value(tf_var, numpy_weight) - return tf_vars - - -def load_tf_vars(model_tf, tf_vars): - for tf_var in tf_vars: - model_tf.get_layer(tf_var.name).set_weights(tf_var) - return model_tf diff --git a/TTS/tts/tf/utils/generic_utils.py b/TTS/tts/tf/utils/generic_utils.py deleted file mode 100644 index 681a9457..00000000 --- a/TTS/tts/tf/utils/generic_utils.py +++ /dev/null @@ -1,105 +0,0 @@ -import datetime -import importlib -import pickle - -import fsspec -import numpy as np -import tensorflow as tf - - -def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): - state = { - "model": model.weights, - "optimizer": optimizer, - "step": current_step, - "epoch": epoch, - "date": datetime.date.today().strftime("%B %d, %Y"), - "r": r, - } - state.update(kwargs) - with fsspec.open(output_path, "wb") as f: - pickle.dump(state, f) - - -def load_checkpoint(model, checkpoint_path): - with fsspec.open(checkpoint_path, "rb") as f: - checkpoint = pickle.load(f) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} - tf_vars = model.weights - for tf_var in tf_vars: - layer_name = tf_var.name - try: - chkp_var_value = chkp_var_dict[layer_name] - except KeyError: - class_name = list(chkp_var_dict.keys())[0].split("/")[0] - layer_name = f"{class_name}/{layer_name}" - chkp_var_value = chkp_var_dict[layer_name] - - tf.keras.backend.set_value(tf_var, chkp_var_value) - if "r" in checkpoint.keys(): - model.decoder.set_r(checkpoint["r"]) - return model - - -def sequence_mask(sequence_length, max_len=None): - if max_len is None: - max_len = sequence_length.max() - batch_size = sequence_length.size(0) - seq_range = np.empty([0, max_len], dtype=np.int8) - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - seq_range_expand = seq_range_expand.type_as(sequence_length) - seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) - # B x T_max - return seq_range_expand < seq_length_expand - - -# @tf.custom_gradient -def check_gradient(x, grad_clip): - x_normed = tf.clip_by_norm(x, grad_clip) - grad_norm = tf.norm(grad_clip) - return x_normed, grad_norm - - -def count_parameters(model, c): - try: - return model.count_params() - except RuntimeError: - input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32")) - input_lengths = np.random.randint(100, 129, (8,)) - input_lengths[-1] = 128 - input_lengths = tf.convert_to_tensor(input_lengths.astype("int32")) - mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32") - mel_spec = tf.convert_to_tensor(mel_spec) - speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None - _ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids) - return model.count_params() - - -def setup_model(num_chars, num_speakers, c, enable_tflite=False): - print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower()) - MyModel = getattr(MyModel, c.model) - if c.model.lower() in "tacotron": - raise NotImplementedError(" [!] 
Tacotron model is not ready.") - # tacotron2 - model = MyModel( - num_chars=num_chars, - num_speakers=num_speakers, - r=c.r, - out_channels=c.audio["num_mels"], - decoder_output_dim=c.audio["num_mels"], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - enable_tflite=enable_tflite, - ) - return model diff --git a/TTS/tts/tf/utils/io.py b/TTS/tts/tf/utils/io.py deleted file mode 100644 index de6acff9..00000000 --- a/TTS/tts/tf/utils/io.py +++ /dev/null @@ -1,45 +0,0 @@ -import datetime -import pickle - -import fsspec -import tensorflow as tf - - -def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): - state = { - "model": model.weights, - "optimizer": optimizer, - "step": current_step, - "epoch": epoch, - "date": datetime.date.today().strftime("%B %d, %Y"), - "r": r, - } - state.update(kwargs) - with fsspec.open(output_path, "wb") as f: - pickle.dump(state, f) - - -def load_checkpoint(model, checkpoint_path): - with fsspec.open(checkpoint_path, "rb") as f: - checkpoint = pickle.load(f) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} - tf_vars = model.weights - for tf_var in tf_vars: - layer_name = tf_var.name - try: - chkp_var_value = chkp_var_dict[layer_name] - except KeyError: - class_name = list(chkp_var_dict.keys())[0].split("/")[0] - layer_name = f"{class_name}/{layer_name}" - chkp_var_value = chkp_var_dict[layer_name] - - tf.keras.backend.set_value(tf_var, chkp_var_value) - if "r" in checkpoint.keys(): - model.decoder.set_r(checkpoint["r"]) - return model - - -def load_tflite_model(tflite_path): - tflite_model = tf.lite.Interpreter(model_path=tflite_path) - tflite_model.allocate_tensors() - return tflite_model diff --git a/TTS/tts/tf/utils/tf_utils.py b/TTS/tts/tf/utils/tf_utils.py deleted file mode 100644 index 558936d5..00000000 --- a/TTS/tts/tf/utils/tf_utils.py +++ /dev/null @@ -1,8 +0,0 @@ -import tensorflow as tf - - -def shape_list(x): - """Deal with dynamic shape in tensorflow cleanly.""" - static = x.shape.as_list() - dynamic = tf.shape(x) - return [dynamic[i] if s is None else s for i, s in enumerate(static)] diff --git a/TTS/tts/tf/utils/tflite.py b/TTS/tts/tf/utils/tflite.py deleted file mode 100644 index 2f76aa50..00000000 --- a/TTS/tts/tf/utils/tflite.py +++ /dev/null @@ -1,27 +0,0 @@ -import fsspec -import tensorflow as tf - - -def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True): - """Convert Tensorflow Tacotron2 model to TFLite. 
Save a binary file if output_path is - provided, else return TFLite model.""" - - concrete_function = model.inference_tflite.get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function]) - converter.experimental_new_converter = experimental_converter - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - tflite_model = converter.convert() - print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") - if output_path is not None: - # same model binary if outputpath is provided - with fsspec.open(output_path, "wb") as f: - f.write(tflite_model) - return None - return tflite_model - - -def load_tflite_model(tflite_path): - tflite_model = tf.lite.Interpreter(model_path=tflite_path) - tflite_model.allocate_tensors() - return tflite_model diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index b0a010b0..c2e7f561 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -57,40 +57,65 @@ def sequence_mask(sequence_length, max_len=None): return mask -def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4): +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): """Segment each sample in a batch based on the provided segment indices Args: x (torch.tensor): Input tensor. segment_indices (torch.tensor): Segment indices. segment_size (int): Expected output segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. """ + # pad the input tensor if it is shorter than the segment size + if pad_short and x.shape[-1] < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + segments = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): index_start = segment_indices[i] index_end = index_start + segment_size - segments[i] = x[i, :, index_start:index_end] + x_i = x[i] + if pad_short and index_end > x.size(2): + # pad the sample if it is shorter than the segment size + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + segments[i] = x_i[:, index_start:index_end] return segments -def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4): +def rand_segments( + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False +): """Create random segments based on the input lengths. Args: x (torch.tensor): Input tensor. x_lengths (torch.tensor): Input lengths. segment_size (int): Expected output segment size. + let_short_samples (bool): Allow shorter samples than the segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. Shapes: - x: :math:`[B, C, T]` - x_lengths: :math:`[B]` """ + _x_lenghts = x_lengths.clone() B, _, T = x.size() - if x_lengths is None: - x_lengths = T - max_idxs = x_lengths - segment_size + 1 - assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size." - segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long() + if pad_short: + if T < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - T)) + T = segment_size + if _x_lenghts is None: + _x_lenghts = T + len_diff = _x_lenghts - segment_size + 1 + if let_short_samples: + _x_lenghts[len_diff < 0] = segment_size + len_diff = _x_lenghts - segment_size + 1 + else: + assert all( + len_diff > 0 + ), f" [!] 
At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() ret = segment(x, segment_indices, segment_size) return ret, segment_indices diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index fc7eec57..19708c13 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -8,6 +8,8 @@ import torch from coqpit import Coqpit from torch.utils.data.sampler import WeightedRandomSampler +from TTS.config import check_config_and_model_args + class LanguageManager: """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information @@ -98,6 +100,20 @@ class LanguageManager: """ self._save_json(file_path, self.language_id_mapping) + @staticmethod + def init_from_config(config: Coqpit) -> "LanguageManager": + """Initialize the language manager from a Coqpit config. + + Args: + config (Coqpit): Coqpit config. + """ + language_manager = None + if check_config_and_model_args(config, "use_language_embedding", True): + if config.get("language_ids_file", None): + language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + language_manager = LanguageManager(config=config) + return language_manager + def _set_file_path(path): """Find the language_ids.json under the given path or the above it. @@ -113,7 +129,7 @@ def _set_file_path(path): def get_language_weighted_sampler(items: list): - language_names = np.array([item[3] for item in items]) + language_names = np.array([item["language"] for item in items]) unique_language_names = np.unique(language_names).tolist() language_ids = [unique_language_names.index(l) for l in language_names] language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names]) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 07076d90..99d653e6 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -9,7 +9,7 @@ import torch from coqpit import Coqpit from torch.utils.data.sampler import WeightedRandomSampler -from TTS.config import load_config +from TTS.config import get_from_config_or_model_args_with_default, load_config from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.utils.audio import AudioProcessor @@ -118,7 +118,7 @@ class SpeakerManager: Returns: Tuple[Dict, int]: speaker IDs and number of speakers. """ - speakers = sorted({item[2] for item in items}) + speakers = sorted({item["speaker_name"] for item in items}) speaker_ids = {name: i for i, name in enumerate(speakers)} num_speakers = len(speaker_ids) return speaker_ids, num_speakers @@ -318,6 +318,42 @@ class SpeakerManager: # TODO: implement speaker encoder raise NotImplementedError + @staticmethod + def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager": + """Initialize a speaker manager from config + + Args: + config (Coqpit): Config object. + samples (Union[List[List], List[Dict]], optional): List of data samples to parse out the speaker names. + Defaults to None. + + Returns: + SpeakerEncoder: Speaker encoder object. 
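Note on the `rand_segments()` change in `TTS/tts/utils/helpers.py` above: the new `pad_short` and `let_short_samples` flags let clips shorter than the requested segment size be zero-padded instead of tripping the assertion. A minimal usage sketch (illustrative only; the tensor shapes below are made up):

    import torch
    from TTS.tts.utils.helpers import rand_segments

    # Batch of 2 spectrograms [B, C, T]; both clips are shorter than the requested segment size.
    x = torch.randn(2, 80, 6)
    x_lens = torch.tensor([6, 3])

    # The old code asserted here; with the new flags short clips are zero-padded up to segment_size.
    segments, indices = rand_segments(x, x_lens, segment_size=8, let_short_samples=True, pad_short=True)
    print(segments.shape)  # torch.Size([2, 80, 8])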
+ """ + speaker_manager = None + if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False): + if samples: + speaker_manager = SpeakerManager(data_items=samples) + if get_from_config_or_model_args_with_default(config, "speaker_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None) + ) + + if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False): + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "d_vector_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None) + ) + return speaker_manager + def _set_file_path(path): """Find the speakers.json under the given path or the above it. @@ -414,7 +450,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, def get_speaker_weighted_sampler(items: list): - speaker_names = np.array([item[2] for item in items]) + speaker_names = np.array([item["speaker_name"] for item in items]) unique_speaker_names = np.unique(speaker_names).tolist() speaker_ids = [unique_speaker_names.index(l) for l in speaker_names] speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names]) diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index 883efdb8..ab2c6991 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -8,7 +8,7 @@ from torch.autograd import Variable def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)]) + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) return gauss / gauss.sum() @@ -33,8 +33,8 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 - C1 = 0.01 ** 2 - C2 = 0.03 ** 2 + C1 = 0.01**2 + C2 = 0.03**2 ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 24b747be..b6e19ab4 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -1,46 +1,9 @@ -import os from typing import Dict import numpy as np -import pkg_resources import torch from torch import nn -from .text import phoneme_to_sequence, text_to_sequence - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - -installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable -if "tensorflow" in installed or "tensorflow-gpu" in installed: - import tensorflow as tf - - -def text_to_seq(text, CONFIG, custom_symbols=None, language=None): - text_cleaner = [CONFIG.text_cleaner] - # text ot phonemes to sequence vector - if CONFIG.use_phonemes: - seq = np.asarray( - phoneme_to_sequence( - text, - text_cleaner, - language if language else 
CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters, - add_blank=CONFIG.add_blank, - use_espeak_phonemes=CONFIG.use_espeak_phonemes, - custom_symbols=custom_symbols, - ), - dtype=np.int32, - ) - else: - seq = np.asarray( - text_to_sequence( - text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols - ), - dtype=np.int32, - ) - return seq - def numpy_to_torch(np_array, dtype, cuda=False): if np_array is None: @@ -51,13 +14,6 @@ def numpy_to_torch(np_array, dtype, cuda=False): return tensor -def numpy_to_tf(np_array, dtype): - if np_array is None: - return None - tensor = tf.convert_to_tensor(np_array, dtype=dtype) - return tensor - - def compute_style_mel(style_wav, ap, cuda=False): style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) if cuda: @@ -103,53 +59,6 @@ def run_model_torch( return outputs -def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None): - if CONFIG.gst and style_mel is not None: - raise NotImplementedError(" [!] GST inference not implemented for TF") - if speaker_id is not None: - raise NotImplementedError(" [!] Multi-Speaker not implemented for TF") - # TODO: handle multispeaker case - decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False) - return decoder_output, postnet_output, alignments, stop_tokens - - -def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None): - if CONFIG.gst and style_mel is not None: - raise NotImplementedError(" [!] GST inference not implemented for TfLite") - if speaker_id is not None: - raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite") - # get input and output details - input_details = model.get_input_details() - output_details = model.get_output_details() - # reshape input tensor for the new input shape - model.resize_tensor_input(input_details[0]["index"], inputs.shape) - model.allocate_tensors() - detail = input_details[0] - # input_shape = detail['shape'] - model.set_tensor(detail["index"], inputs) - # run the model - model.invoke() - # collect outputs - decoder_output = model.get_tensor(output_details[0]["index"]) - postnet_output = model.get_tensor(output_details[1]["index"]) - # tflite model only returns feature frames - return decoder_output, postnet_output, None, None - - -def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens): - postnet_output = postnet_output[0].numpy() - decoder_output = decoder_output[0].numpy() - alignment = alignments[0].numpy() - stop_tokens = stop_tokens[0].numpy() - return postnet_output, decoder_output, alignment, stop_tokens - - -def parse_outputs_tflite(postnet_output, decoder_output): - postnet_output = postnet_output[0] - decoder_output = decoder_output[0] - return postnet_output, decoder_output - - def trim_silence(wav, ap): return wav[: ap.find_endpoint(wav)] @@ -204,16 +113,12 @@ def synthesis( text, CONFIG, use_cuda, - ap, speaker_id=None, style_wav=None, - enable_eos_bos_chars=False, # pylint: disable=unused-argument use_griffin_lim=False, do_trim_silence=False, d_vector=None, language_id=None, - language_name=None, - backend="torch", ): """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to the vocoder model. @@ -231,9 +136,6 @@ def synthesis( use_cuda (bool): Enable/disable CUDA. - ap (TTS.tts.utils.audio.AudioProcessor): - The audio processor for extracting features and pre/post-processing audio. 
- speaker_id (int): Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. @@ -251,74 +153,51 @@ def synthesis( language_id (int): Language ID passed to the language embedding layer in multi-langual model. Defaults to None. - - language_name (str): - Language name corresponding to the language code used by the phonemizer. Defaults to None. - - backend (str): - tf or torch. Defaults to "torch". """ # GST processing style_mel = None - custom_symbols = None - if style_wav: - style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) - elif CONFIG.has("gst") and CONFIG.gst and not style_wav: - if CONFIG.gst.gst_style_input_weights: - style_mel = CONFIG.gst.gst_style_input_weights - if hasattr(model, "make_symbols"): - custom_symbols = model.make_symbols(CONFIG) - # preprocess the given text - text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name) + if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: + if isinstance(style_wav, dict): + style_mel = style_wav + else: + style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + # convert text to sequence of token IDs + text_inputs = np.asarray( + model.tokenizer.text_to_ids(text, language=language_id), + dtype=np.int32, + ) # pass tensors to backend - if backend == "torch": - if speaker_id is not None: - speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=use_cuda) - if d_vector is not None: - d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=use_cuda) - if language_id is not None: - language_id = id_to_torch(language_id, cuda=use_cuda) + if language_id is not None: + language_id = id_to_torch(language_id, cuda=use_cuda) - if not isinstance(style_mel, dict): - style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) - text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) - text_inputs = text_inputs.unsqueeze(0) - elif backend in ["tf", "tflite"]: - # TODO: handle speaker id for tf model - style_mel = numpy_to_tf(style_mel, tf.float32) - text_inputs = numpy_to_tf(text_inputs, tf.int32) - text_inputs = tf.expand_dims(text_inputs, 0) + if not isinstance(style_mel, dict): + style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) + text_inputs = text_inputs.unsqueeze(0) # synthesize voice - if backend == "torch": - outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id) - model_outputs = outputs["model_outputs"] - model_outputs = model_outputs[0].data.cpu().numpy() - alignments = outputs["alignments"] - elif backend == "tf": - decoder_output, postnet_output, alignments, stop_tokens = run_model_tf( - model, text_inputs, CONFIG, speaker_id, style_mel - ) - model_outputs, decoder_output, alignments, stop_tokens = parse_outputs_tf( - postnet_output, decoder_output, alignments, stop_tokens - ) - elif backend == "tflite": - decoder_output, postnet_output, alignments, stop_tokens = run_model_tflite( - model, text_inputs, CONFIG, speaker_id, style_mel - ) - model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output) + outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id) + model_outputs = outputs["model_outputs"] + model_outputs = model_outputs[0].data.cpu().numpy() + 
alignments = outputs["alignments"] + # convert outputs to numpy # plot results wav = None - if hasattr(model, "END2END") and model.END2END: - wav = model_outputs.squeeze(0) - else: + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] if use_griffin_lim: - wav = inv_spectrogram(model_outputs, ap, CONFIG) + wav = inv_spectrogram(model_outputs, model.ap, CONFIG) # trim silence if do_trim_silence: - wav = trim_silence(wav, ap) + wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs return_dict = { "wav": wav, "alignments": alignments, diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 537d2301..593372dc 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -1,276 +1 @@ -# -*- coding: utf-8 -*- -# adapted from https://github.com/keithito/tacotron - -import re -from typing import Dict, List - -import gruut -from gruut_ipa import IPA - -from TTS.tts.utils.text import cleaners -from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes -from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes -from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols - -# pylint: disable=unnecessary-comprehension -# Mappings from symbol to numeric ID and vice versa: -_symbol_to_id = {s: i for i, s in enumerate(symbols)} -_id_to_symbol = {i: s for i, s in enumerate(symbols)} - -_phonemes_to_id = {s: i for i, s in enumerate(phonemes)} -_id_to_phonemes = {i: s for i, s in enumerate(phonemes)} - -_symbols = symbols -_phonemes = phonemes - -# Regular expression matching text enclosed in curly braces: -_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)") - -# Regular expression matching punctuations, ignoring empty space -PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+" - -# Table for str.translate to fix gruut/TTS phoneme mismatch -GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ") - - -def text2phone(text, language, use_espeak_phonemes=False, keep_stress=False): - """Convert graphemes to phonemes. - Parameters: - text (str): text to phonemize - language (str): language of the text - Returns: - ph (str): phonemes as a string seperated by "|" - ph = "ɪ|g|ˈ|z|æ|m|p|ə|l" - """ - - # TO REVIEW : How to have a good implementation for this? - if language == "zh-CN": - ph = chinese_text_to_phonemes(text) - return ph - - if language == "ja-jp": - ph = japanese_text_to_phonemes(text) - return ph - - if not gruut.is_language_supported(language): - raise ValueError(f" [!] Language {language} is not supported for phonemization.") - - # Use gruut for phonemization - ph_list = [] - for sentence in gruut.sentences(text, lang=language, espeak=use_espeak_phonemes): - for word in sentence: - if word.is_break: - # Use actual character for break phoneme (e.g., comma) - if ph_list: - # Join with previous word - ph_list[-1].append(word.text) - else: - # First word is punctuation - ph_list.append([word.text]) - elif word.phonemes: - # Add phonemes for word - word_phonemes = [] - - for word_phoneme in word.phonemes: - if not keep_stress: - # Remove primary/secondary stress - word_phoneme = IPA.without_stress(word_phoneme) - - word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE) - - if word_phoneme: - # Flatten phonemes - word_phonemes.extend(word_phoneme) - - if word_phonemes: - ph_list.append(word_phonemes) - - # Join and re-split to break apart dipthongs, suprasegmentals, etc. 
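With the TF and TFLite paths removed, `synthesis()` above is Torch-only and now pulls the tokenizer and audio processor from the model itself instead of taking `ap`, `backend`, and `language_name` arguments. A hedged sketch of the slimmed-down call, assuming `model` is a 🐸TTS model that exposes `tokenizer` and `ap`, and `config` is its Coqpit config:

    from TTS.tts.utils.synthesis import synthesis

    outputs = synthesis(
        model,                 # must provide model.tokenizer and model.ap after this refactor
        "Hello world.",
        config,
        use_cuda=False,
        speaker_id=None,
        use_griffin_lim=True,  # decode spectrogram outputs with Griffin-Lim
        do_trim_silence=True,
    )
    wav = outputs["wav"]
    alignments = outputs["alignments"]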
- ph_words = ["|".join(word_phonemes) for word_phonemes in ph_list] - ph = "| ".join(ph_words) - - return ph - - -def intersperse(sequence, token): - result = [token] * (len(sequence) * 2 + 1) - result[1::2] = sequence - return result - - -def pad_with_eos_bos(phoneme_sequence, tp=None): - # pylint: disable=global-statement - global _phonemes_to_id, _bos, _eos - if tp: - _bos = tp["bos"] - _eos = tp["eos"] - _, _phonemes = make_symbols(**tp) - _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} - - return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] - - -def phoneme_to_sequence( - text: str, - cleaner_names: List[str], - language: str, - enable_eos_bos: bool = False, - custom_symbols: List[str] = None, - tp: Dict = None, - add_blank: bool = False, - use_espeak_phonemes: bool = False, -) -> List[int]: - """Converts a string of phonemes to a sequence of IDs. - If `custom_symbols` is provided, it will override the default symbols. - - Args: - text (str): string to convert to a sequence - cleaner_names (List[str]): names of the cleaner functions to run the text through - language (str): text language key for phonemization. - enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens. - tp (Dict): dictionary of character parameters to use a custom character set. - add_blank (bool): option to add a blank token between each token. - use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequenc - - Returns: - List[int]: List of integers corresponding to the symbols in the text - """ - # pylint: disable=global-statement - global _phonemes_to_id, _phonemes - - if custom_symbols is not None: - _phonemes = custom_symbols - elif tp: - _, _phonemes = make_symbols(**tp) - _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} - - sequence = [] - clean_text = _clean_text(text, cleaner_names) - to_phonemes = text2phone(clean_text, language, use_espeak_phonemes=use_espeak_phonemes) - if to_phonemes is None: - print("!! After phoneme conversion the result is None. -- {} ".format(clean_text)) - # iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation. - for phoneme in filter(None, to_phonemes.split("|")): - sequence += _phoneme_to_sequence(phoneme) - # Append EOS char - if enable_eos_bos: - sequence = pad_with_eos_bos(sequence, tp=tp) - if add_blank: - sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) - return sequence - - -def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None): - # pylint: disable=global-statement - """Converts a sequence of IDs back to a string""" - global _id_to_phonemes, _phonemes - if add_blank: - sequence = list(filter(lambda x: x != len(_phonemes), sequence)) - result = "" - - if custom_symbols is not None: - _phonemes = custom_symbols - elif tp: - _, _phonemes = make_symbols(**tp) - _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} - - for symbol_id in sequence: - if symbol_id in _id_to_phonemes: - s = _id_to_phonemes[symbol_id] - result += s - return result.replace("}{", " ") - - -def text_to_sequence( - text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False -) -> List[int]: - """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - If `custom_symbols` is provided, it will override the default symbols. 
- - Args: - text (str): string to convert to a sequence - cleaner_names (List[str]): names of the cleaner functions to run the text through - tp (Dict): dictionary of character parameters to use a custom character set. - add_blank (bool): option to add a blank token between each token. - - Returns: - List[int]: List of integers corresponding to the symbols in the text - """ - # pylint: disable=global-statement - global _symbol_to_id, _symbols - - if custom_symbols is not None: - _symbols = custom_symbols - elif tp: - _symbols, _ = make_symbols(**tp) - _symbol_to_id = {s: i for i, s in enumerate(_symbols)} - - sequence = [] - - # Check for curly braces and treat their contents as ARPAbet: - while text: - m = _CURLY_RE.match(text) - if not m: - sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) - break - sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) - sequence += _arpabet_to_sequence(m.group(2)) - text = m.group(3) - - if add_blank: - sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols) - return sequence - - -def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None): - """Converts a sequence of IDs back to a string""" - # pylint: disable=global-statement - global _id_to_symbol, _symbols - if add_blank: - sequence = list(filter(lambda x: x != len(_symbols), sequence)) - - if custom_symbols is not None: - _symbols = custom_symbols - _id_to_symbol = {i: s for i, s in enumerate(_symbols)} - elif tp: - _symbols, _ = make_symbols(**tp) - _id_to_symbol = {i: s for i, s in enumerate(_symbols)} - - result = "" - for symbol_id in sequence: - if symbol_id in _id_to_symbol: - s = _id_to_symbol[symbol_id] - # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == "@": - s = "{%s}" % s[1:] - result += s - return result.replace("}{", " ") - - -def _clean_text(text, cleaner_names): - for name in cleaner_names: - cleaner = getattr(cleaners, name) - if not cleaner: - raise Exception("Unknown cleaner: %s" % name) - text = cleaner(text) - return text - - -def _symbols_to_sequence(syms): - return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)] - - -def _phoneme_to_sequence(phons): - return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)] - - -def _arpabet_to_sequence(text): - return _symbols_to_sequence(["@" + s for s in text.split()]) - - -def _should_keep_symbol(s): - return s in _symbol_to_id and s not in ["~", "^", "_"] - - -def _should_keep_phoneme(p): - return p in _phonemes_to_id and p not in ["~", "^", "_"] +from TTS.tts.utils.text.tokenizer import TTSTokenizer diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py new file mode 100644 index 00000000..1b375e4f --- /dev/null +++ b/TTS/tts/utils/text/characters.py @@ -0,0 +1,468 @@ +from dataclasses import replace +from typing import Dict + +from TTS.tts.configs.shared_configs import CharactersConfig + + +def parse_symbols(): + return { + "pad": _pad, + "eos": _eos, + "bos": _bos, + "characters": _characters, + "punctuations": _punctuations, + "phonemes": _phonemes, + } + + +# DEFAULT SET OF GRAPHEMES +_pad = "" +_eos = "" +_bos = "" +_blank = "" # TODO: check if we need this alongside with PAD +_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_punctuations = "!'(),-.:;? 
" + + +# DEFAULT SET OF IPA PHONEMES +# Phonemes definition (All IPA characters) +_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" +_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" +_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" +_suprasegmentals = "ˈˌːˑ" +_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" +_diacrilics = "ɚ˞ɫ" +_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics + + +class BaseVocabulary: + """Base Vocabulary class. + + This class only needs a vocabulary dictionary without specifying the characters. + + Args: + vocab (Dict): A dictionary of characters and their corresponding indices. + """ + + def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None): + self.vocab = vocab + self.pad = pad + self.blank = blank + self.bos = bos + self.eos = eos + + @property + def pad_id(self) -> int: + """Return the index of the padding character. If the padding character is not specified, return the length + of the vocabulary.""" + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + """Return the index of the blank character. If the blank character is not specified, return the length of + the vocabulary.""" + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def vocab(self): + """Return the vocabulary dictionary.""" + return self._vocab + + @vocab.setter + def vocab(self, vocab): + """Set the vocabulary dictionary and character mapping dictionaries.""" + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension + } + + @staticmethod + def init_from_config(config, **kwargs): + """Initialize from the given config.""" + if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict: + return ( + BaseVocabulary( + config.characters.vocab_dict, + config.characters.pad, + config.characters.blank, + config.characters.bos, + config.characters.eos, + ), + config, + ) + return BaseVocabulary(**kwargs), config + + @property + def num_chars(self): + """Return number of tokens in the vocabulary.""" + return len(self._vocab) + + def char_to_id(self, char: str) -> int: + """Map a character to an token ID.""" + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e + + def id_to_char(self, idx: int) -> str: + """Map an token ID to a character.""" + return self._id_to_char[idx] + + +class BaseCharacters: + """🐸BaseCharacters class + + Every new character class should inherit from this. + + Characters are oredered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```. + + If you need a custom order, you need to define inherit from this class and override the ```_create_vocab``` method. + + Args: + characters (str): + Main set of characters to be used in the vocabulary. + + punctuations (str): + Characters to be treated as punctuation. + + pad (str): + Special padding character that would be ignored by the model. + + eos (str): + End of the sentence character. + + bos (str): + Beginning of the sentence character. + + blank (str): + Optional character used between characters by some models for better prosody. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. 
+ el + is_sorted (bool): + Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True. + """ + + def __init__( + self, + characters: str = None, + punctuations: str = None, + pad: str = None, + eos: str = None, + bos: str = None, + blank: str = None, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + self._characters = characters + self._punctuations = punctuations + self._pad = pad + self._eos = eos + self._bos = bos + self._blank = blank + self.is_unique = is_unique + self.is_sorted = is_sorted + self._create_vocab() + + @property + def pad_id(self) -> int: + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, characters): + self._characters = characters + self._create_vocab() + + @property + def punctuations(self): + return self._punctuations + + @punctuations.setter + def punctuations(self, punctuations): + self._punctuations = punctuations + self._create_vocab() + + @property + def pad(self): + return self._pad + + @pad.setter + def pad(self, pad): + self._pad = pad + self._create_vocab() + + @property + def eos(self): + return self._eos + + @eos.setter + def eos(self, eos): + self._eos = eos + self._create_vocab() + + @property + def bos(self): + return self._bos + + @bos.setter + def bos(self, bos): + self._bos = bos + self._create_vocab() + + @property + def blank(self): + return self._blank + + @blank.setter + def blank(self, blank): + self._blank = blank + self._create_vocab() + + @property + def vocab(self): + return self._vocab + + @vocab.setter + def vocab(self, vocab): + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension + } + + @property + def num_chars(self): + return len(self._vocab) + + def _create_vocab(self): + _vocab = self._characters + if self.is_unique: + _vocab = list(set(_vocab)) + if self.is_sorted: + _vocab = sorted(_vocab) + _vocab = list(_vocab) + _vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab + _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab + _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab + _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab + self.vocab = _vocab + list(self._punctuations) + if self.is_unique: + duplicates = {x for x in self.vocab if self.vocab.count(x) > 1} + assert ( + len(self.vocab) == len(self._char_to_id) == len(self._id_to_char) + ), f" [!] There are duplicate characters in the character set. {duplicates}" + + def char_to_id(self, char: str) -> int: + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e + + def id_to_char(self, idx: int) -> str: + return self._id_to_char[idx] + + def print_log(self, level: int = 0): + """ + Prints the vocabulary in a nice format. 
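For orientation, `_create_vocab()` above assembles the vocabulary in the order [PAD, EOS, BOS, BLANK, characters, punctuations], as the class docstring states. A small sketch (not part of the diff, with made-up special tokens) showing the resulting mapping:

    from TTS.tts.utils.text.characters import BaseCharacters

    chars = BaseCharacters(
        characters="cab",
        punctuations="!? ",
        pad="<PAD>",
        eos="<EOS>",
        bos="<BOS>",
        blank="<BLNK>",
        is_unique=True,
        is_sorted=True,  # "cab" is sorted to "abc" before the special tokens are prepended
    )
    print(chars.vocab)             # ['<PAD>', '<EOS>', '<BOS>', '<BLNK>', 'a', 'b', 'c', '!', '?', ' ']
    print(chars.char_to_id("a"))   # 4
    print(chars.id_to_char(0))     # '<PAD>'
    print(chars.pad_id, chars.blank_id)  # 0 3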
+ """ + indent = "\t" * level + print(f"{indent}| > Characters: {self._characters}") + print(f"{indent}| > Punctuations: {self._punctuations}") + print(f"{indent}| > Pad: {self._pad}") + print(f"{indent}| > EOS: {self._eos}") + print(f"{indent}| > BOS: {self._bos}") + print(f"{indent}| > Blank: {self._blank}") + print(f"{indent}| > Vocab: {self.vocab}") + print(f"{indent}| > Num chars: {self.num_chars}") + + @staticmethod + def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument + """Init your character class from a config. + + Implement this method for your subclass. + """ + # use character set from config + if config.characters is not None: + return BaseCharacters(**config.characters), config + # return default character set + characters = BaseCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=self._eos, + bos=self._bos, + blank=self._blank, + is_unique=self.is_unique, + is_sorted=self.is_sorted, + ) + + +class IPAPhonemes(BaseCharacters): + """🐸IPAPhonemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using IPAPhonemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _phonemes, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + """Init a IPAPhonemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. 
+ """ + # band-aid for compatibility with old models + if "characters" in config and config.characters is not None: + if "phonemes" in config.characters and config.characters.phonemes is not None: + config.characters["characters"] = config.characters["phonemes"] + return ( + IPAPhonemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + # use character set from config + if config.characters is not None: + return IPAPhonemes(**config.characters), config + # return default character set + characters = IPAPhonemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + +class Graphemes(BaseCharacters): + """🐸Graphemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using graphemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + """Init a Graphemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. 
+ """ + if config.characters is not None: + # band-aid for compatibility with old models + if "phonemes" in config.characters: + return ( + Graphemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + return Graphemes(**config.characters), config + characters = Graphemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + +if __name__ == "__main__": + gr = Graphemes() + ph = IPAPhonemes() + gr.print_log() + ph.print_log() diff --git a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py index 29cac160..727c881e 100644 --- a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py +++ b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py @@ -19,7 +19,7 @@ def _chinese_pinyin_to_phoneme(pinyin: str) -> str: return phoneme + tone -def chinese_text_to_phonemes(text: str) -> str: +def chinese_text_to_phonemes(text: str, seperator: str = "|") -> str: tokenized_text = jieba.cut(text, HMM=False) tokenized_text = " ".join(tokenized_text) pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text) @@ -34,4 +34,4 @@ def chinese_text_to_phonemes(text: str) -> str: else: # is ponctuation or other results += list(token) - return "|".join(results) + return seperator.join(results) diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index f3ffa478..f02f8fb4 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -1,12 +1,16 @@ +"""Set of default text cleaners""" +# TODO: pick the cleaner for languages dynamically + import re from anyascii import anyascii from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text -from .abbreviations import abbreviations_en, abbreviations_fr -from .number_norm import normalize_numbers -from .time import expand_time_english +from .english.abbreviations import abbreviations_en +from .english.number_norm import normalize_numbers as en_normalize_numbers +from .english.time_norm import expand_time_english +from .french.abbreviations import abbreviations_fr # Regular expression matching whitespace: _whitespace_re = re.compile(r"\s+") @@ -22,10 +26,6 @@ def expand_abbreviations(text, lang="en"): return text -def expand_numbers(text): - return normalize_numbers(text) - - def lowercase(text): return text.lower() @@ -92,7 +92,17 @@ def english_cleaners(text): # text = convert_to_ascii(text) text = lowercase(text) text = expand_time_english(text) - text = expand_numbers(text) + text = en_normalize_numbers(text) + text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def phoneme_cleaners(text): + """Pipeline for phonemes mode, including number and abbreviation expansion.""" + text = en_normalize_numbers(text) text = expand_abbreviations(text) text = replace_symbols(text) text = remove_aux_symbols(text) @@ -126,17 +136,6 @@ def chinese_mandarin_cleaners(text: str) -> str: return text -def phoneme_cleaners(text): - """Pipeline for phonemes mode, including number and abbreviation expansion.""" - text = expand_numbers(text) - # text = convert_to_ascii(text) - text = expand_abbreviations(text) - text = replace_symbols(text) 
- text = remove_aux_symbols(text) - text = collapse_whitespace(text) - return text - - def multilingual_cleaners(text): """Pipeline for multilingual text""" text = lowercase(text) diff --git a/TTS/tts/tf/__init__.py b/TTS/tts/utils/text/english/__init__.py similarity index 100% rename from TTS/tts/tf/__init__.py rename to TTS/tts/utils/text/english/__init__.py diff --git a/TTS/tts/utils/text/english/abbreviations.py b/TTS/tts/utils/text/english/abbreviations.py new file mode 100644 index 00000000..cd93c13c --- /dev/null +++ b/TTS/tts/utils/text/english/abbreviations.py @@ -0,0 +1,26 @@ +import re + +# List of (regular expression, replacement) pairs for abbreviations in english: +abbreviations_en = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] diff --git a/TTS/tts/utils/text/number_norm.py b/TTS/tts/utils/text/english/number_norm.py similarity index 100% rename from TTS/tts/utils/text/number_norm.py rename to TTS/tts/utils/text/english/number_norm.py diff --git a/TTS/tts/utils/text/time.py b/TTS/tts/utils/text/english/time_norm.py similarity index 100% rename from TTS/tts/utils/text/time.py rename to TTS/tts/utils/text/english/time_norm.py diff --git a/TTS/tts/tf/layers/tacotron/__init__.py b/TTS/tts/utils/text/french/__init__.py similarity index 100% rename from TTS/tts/tf/layers/tacotron/__init__.py rename to TTS/tts/utils/text/french/__init__.py diff --git a/TTS/tts/utils/text/abbreviations.py b/TTS/tts/utils/text/french/abbreviations.py similarity index 66% rename from TTS/tts/utils/text/abbreviations.py rename to TTS/tts/utils/text/french/abbreviations.py index 7e44b90c..f580dfed 100644 --- a/TTS/tts/utils/text/abbreviations.py +++ b/TTS/tts/utils/text/french/abbreviations.py @@ -1,30 +1,5 @@ import re -# List of (regular expression, replacement) pairs for abbreviations in english: -abbreviations_en = [ - (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) - for x in [ - ("mrs", "misess"), - ("mr", "mister"), - ("dr", "doctor"), - ("st", "saint"), - ("co", "company"), - ("jr", "junior"), - ("maj", "major"), - ("gen", "general"), - ("drs", "doctors"), - ("rev", "reverend"), - ("lt", "lieutenant"), - ("hon", "honorable"), - ("sgt", "sergeant"), - ("capt", "captain"), - ("esq", "esquire"), - ("ltd", "limited"), - ("col", "colonel"), - ("ft", "fort"), - ] -] - # List of (regular expression, replacement) pairs for abbreviations in french: abbreviations_fr = [ (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py new file mode 100644 index 00000000..5dc117c4 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -0,0 +1,57 @@ +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak +from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut +from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer +from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer + +PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, JA_JP_Phonemizer)} + + +ESPEAK_LANGS = list(ESpeak.supported_languages().keys()) +GRUUT_LANGS = list(Gruut.supported_languages()) + + +# Dict setting default phonemizers for each language +DEF_LANG_TO_PHONEMIZER = { + "ja-jp": JA_JP_Phonemizer.name(), + "zh-cn": ZH_CN_Phonemizer.name(), +} + + +# Add Gruut languages +_ = [Gruut.name()] * len(GRUUT_LANGS) +_new_dict = dict(list(zip(GRUUT_LANGS, _))) +DEF_LANG_TO_PHONEMIZER.update(_new_dict) + + +# Add ESpeak languages and override any existing ones +_ = [ESpeak.name()] * len(ESPEAK_LANGS) +_new_dict = dict(list(zip(list(ESPEAK_LANGS), _))) +DEF_LANG_TO_PHONEMIZER.update(_new_dict) + +DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"] + + +def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: + """Initiate a phonemizer by name + + Args: + name (str): + Name of the phonemizer that should match `phonemizer.name()`. + + kwargs (dict): + Extra keyword arguments that should be passed to the phonemizer. + """ + if name == "espeak": + return ESpeak(**kwargs) + if name == "gruut": + return Gruut(**kwargs) + if name == "zh_cn_phonemizer": + return ZH_CN_Phonemizer(**kwargs) + if name == "ja_jp_phonemizer": + return JA_JP_Phonemizer(**kwargs) + raise ValueError(f"Phonemizer {name} not found") + + +if __name__ == "__main__": + print(DEF_LANG_TO_PHONEMIZER) diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py new file mode 100644 index 00000000..08fa8e13 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -0,0 +1,141 @@ +import abc +from typing import List, Tuple + +from TTS.tts.utils.text.punctuation import Punctuation + + +class BasePhonemizer(abc.ABC): + """Base phonemizer class + + Phonemization follows the following steps: + 1. Preprocessing: + - remove empty lines + - remove punctuation + - keep track of punctuation marks + + 2. Phonemization: + - convert text to phonemes + + 3. Postprocessing: + - join phonemes + - restore punctuation marks + + Args: + language (str): + Language used by the phonemizer. + + punctuations (List[str]): + List of punctuation marks to be preserved. + + keep_puncs (bool): + Whether to preserve punctuation marks or not. 
+ """ + + def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False): + + # ensure the backend is installed on the system + if not self.is_available(): + raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover + + # ensure the backend support the requested language + self._language = self._init_language(language) + + # setup punctuation processing + self._keep_puncs = keep_puncs + self._punctuator = Punctuation(punctuations) + + def _init_language(self, language): + """Language initialization + + This method may be overloaded in child classes (see Segments backend) + + """ + if not self.is_supported_language(language): + raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend") + return language + + @property + def language(self): + """The language code configured to be used for phonemization""" + return self._language + + @staticmethod + @abc.abstractmethod + def name(): + """The name of the backend""" + ... + + @classmethod + @abc.abstractmethod + def is_available(cls): + """Returns True if the backend is installed, False otherwise""" + ... + + @classmethod + @abc.abstractmethod + def version(cls): + """Return the backend version as a tuple (major, minor, patch)""" + ... + + @staticmethod + @abc.abstractmethod + def supported_languages(): + """Return a dict of language codes -> name supported by the backend""" + ... + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return language in self.supported_languages() + + @abc.abstractmethod + def _phonemize(self, text, separator): + """The main phonemization method""" + + def _phonemize_preprocess(self, text) -> Tuple[List[str], List]: + """Preprocess the text before phonemization + + 1. remove spaces + 2. remove punctuation + + Override this if you need a different behaviour + """ + text = text.strip() + if self._keep_puncs: + # a tuple (text, punctuation marks) + return self._punctuator.strip_to_restore(text) + return [self._punctuator.strip(text)], [] + + def _phonemize_postprocess(self, phonemized, punctuations) -> str: + """Postprocess the raw phonemized output + + Override this if you need a different behaviour + """ + if self._keep_puncs: + return self._punctuator.restore(phonemized, punctuations)[0] + return phonemized[0] + + def phonemize(self, text: str, separator="|") -> str: + """Returns the `text` phonemized for the given language + + Args: + text (str): + Text to be phonemized. + + separator (str): + string separator used between phonemes. Default to '_'. 
+ + Returns: + (str): Phonemized text + """ + text, punctuations = self._phonemize_preprocess(text) + phonemized = [] + for t in text: + p = self._phonemize(t, separator) + phonemized.append(p) + phonemized = self._phonemize_postprocess(phonemized, punctuations) + return phonemized + + def print_logs(self, level: int = 0): + indent = "\t" * level + print(f"{indent}| > phoneme language: {self.language}") + print(f"{indent}| > phoneme backend: {self.name()}") diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py new file mode 100644 index 00000000..442dcef2 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -0,0 +1,225 @@ +import logging +import subprocess +from typing import Dict, List + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + + +def is_tool(name): + from shutil import which + + return which(name) is not None + + +# priority: espeakng > espeak +if is_tool("espeak-ng"): + _DEF_ESPEAK_LIB = "espeak-ng" +elif is_tool("espeak"): + _DEF_ESPEAK_LIB = "espeak" +else: + _DEF_ESPEAK_LIB = None + + +def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: + """Run espeak with the given arguments.""" + cmd = [ + espeak_lib, + "-q", + "-b", + "1", # UTF8 text encoding + ] + cmd.extend(args) + logging.debug("espeakng: executing %s", repr(cmd)) + + with subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as p: + res = iter(p.stdout.readline, b"") + if not sync: + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + return res + res2 = [] + for line in res: + res2.append(line) + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + p.wait() + return res2 + + +class ESpeak(BasePhonemizer): + """ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line the perform G2P + + Args: + language (str): + Valid language code for the used backend. + + backend (str): + Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically + prefering `espeak-ng` over `espeak`. Defaults to None. + + punctuations (str): + Characters to be treated as punctuation. Defaults to Punctuation.default_puncs(). + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to True. + + Example: + + >>> from TTS.tts.utils.text.phonemizers import ESpeak + >>> phonemizer = ESpeak("tr") + >>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|") + 'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.' + + """ + + _ESPEAK_LIB = _DEF_ESPEAK_LIB + + def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True): + if self._ESPEAK_LIB is None: + raise Exception(" [!] No espeak backend found. 
Install espeak-ng or espeak to your system.") + self.backend = self._ESPEAK_LIB + + # band-aid for backwards compatibility + if language == "en": + language = "en-us" + + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + if backend is not None: + self.backend = backend + + @property + def backend(self): + return self._ESPEAK_LIB + + @backend.setter + def backend(self, backend): + if backend not in ["espeak", "espeak-ng"]: + raise Exception("Unknown backend: %s" % backend) + self._ESPEAK_LIB = backend + # skip first two characters of the retuned text + # "_ p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + # ^^ + self.num_skip_chars = 2 + if backend == "espeak-ng": + # skip the first character of the retuned text + # "_p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + # ^ + self.num_skip_chars = 1 + + def auto_set_espeak_lib(self) -> None: + if is_tool("espeak-ng"): + self._ESPEAK_LIB = "espeak-ng" + elif is_tool("espeak"): + self._ESPEAK_LIB = "espeak" + else: + raise Exception("Cannot set backend automatically. espeak-ng or espeak not found") + + @staticmethod + def name(): + return "espeak" + + def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str: + """Convert input text to phonemes. + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. + """ + # set arguments + args = ["-v", f"{self._language}"] + # espeak and espeak-ng parses `ipa` differently + if tie: + # use '͡' between phonemes + if self.backend == "espeak": + args.append("--ipa=1") + else: + args.append("--ipa=3") + else: + # split with '_' + if self.backend == "espeak": + args.append("--ipa=3") + else: + args.append("--ipa=1") + if tie: + args.append("--tie=%s" % tie) + + args.append('"' + text + '"') + # compute phonemes + phonemes = "" + for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): + logging.debug("line: %s", repr(line)) + phonemes += line.decode("utf8").strip()[self.num_skip_chars :] # skip initial redundant characters + return phonemes.replace("_", separator) + + def _phonemize(self, text, separator=None): + return self.phonemize_espeak(text, separator, tie=False) + + @staticmethod + def supported_languages() -> Dict: + """Get a dictionary of supported languages. + + Returns: + Dict: Dictionary of language codes. + """ + if _DEF_ESPEAK_LIB is None: + return {} + args = ["--voices"] + langs = {} + count = 0 + for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True): + line = line.decode("utf8").strip() + if count > 0: + cols = line.split() + lang_code = cols[1] + lang_name = cols[3] + langs[lang_code] = lang_name + logging.debug("line: %s", repr(line)) + count += 1 + return langs + + def version(self) -> str: + """Get the version of the used backend. + + Returns: + str: Version of the used backend. 
+ """ + args = ["--version"] + for line in _espeak_exe(self.backend, args, sync=True): + version = line.decode("utf8").strip().split()[2] + logging.debug("line: %s", repr(line)) + return version + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return is_tool("espeak") or is_tool("espeak-ng") + + +if __name__ == "__main__": + e = ESpeak(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = ESpeak(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = ESpeak(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how are you today?") + "`") diff --git a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py new file mode 100644 index 00000000..f3e9c9ab --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py @@ -0,0 +1,151 @@ +import importlib +from typing import List + +import gruut +from gruut_ipa import IPA + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + +# Table for str.translate to fix gruut/TTS phoneme mismatch +GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ") + + +class Gruut(BasePhonemizer): + """Gruut wrapper for G2P + + Args: + language (str): + Valid language code for the used backend. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`. + + keep_puncs (bool): + If true, keep the punctuations after phonemization. Defaults to True. + + use_espeak_phonemes (bool): + If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False. + + keep_stress (bool): + If true, keep the stress characters after phonemization. Defaults to False. + + Example: + + >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + >>> phonemizer = Gruut('en-us') + >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|") + 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?' + """ + + def __init__( + self, + language: str, + punctuations=Punctuation.default_puncs(), + keep_puncs=True, + use_espeak_phonemes=False, + keep_stress=False, + ): + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + self.use_espeak_phonemes = use_espeak_phonemes + self.keep_stress = keep_stress + + @staticmethod + def name(): + return "gruut" + + def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument + """Convert input text to phonemes. + + Gruut phonemizes the given `str` by seperating each phoneme character with `separator`, even for characters + that constitude a single sound. + + It doesn't affect 🐸TTS since it individually converts each character to token IDs. + + Examples:: + "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ` + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. 
+ """ + ph_list = [] + for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes): + for word in sentence: + if word.is_break: + # Use actual character for break phoneme (e.g., comma) + if ph_list: + # Join with previous word + ph_list[-1].append(word.text) + else: + # First word is punctuation + ph_list.append([word.text]) + elif word.phonemes: + # Add phonemes for word + word_phonemes = [] + + for word_phoneme in word.phonemes: + if not self.keep_stress: + # Remove primary/secondary stress + word_phoneme = IPA.without_stress(word_phoneme) + + word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE) + + if word_phoneme: + # Flatten phonemes + word_phonemes.extend(word_phoneme) + + if word_phonemes: + ph_list.append(word_phonemes) + + ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list] + ph = f"{separator} ".join(ph_words) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_gruut(text, separator, tie=False) + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return gruut.is_language_supported(language) + + @staticmethod + def supported_languages() -> List: + """Get a dictionary of supported languages. + + Returns: + List: List of language codes. + """ + return list(gruut.get_supported_languages()) + + def version(self): + """Get the version of the used backend. + + Returns: + str: Version of the used backend. + """ + return gruut.__version__ + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return importlib.util.find_spec("gruut") is not None + + +if __name__ == "__main__": + e = Gruut(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = Gruut(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = Gruut(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how, are you today?") + "`") diff --git a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py new file mode 100644 index 00000000..60b965f9 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py @@ -0,0 +1,72 @@ +from typing import Dict + +from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】" + +_TRANS_TABLE = {"、": ","} + + +def trans(text): + for i, j in _TRANS_TABLE.items(): + text = text.replace(i, j) + return text + + +class JA_JP_Phonemizer(BasePhonemizer): + """🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer` + + TODO: someone with JA knowledge should check this implementation + + Example: + + >>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer + >>> phonemizer = JA_JP_Phonemizer() + >>> phonemizer.phonemize("どちらに行きますか?", separator="|") + 'd|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?' 
+ + """ + + language = "ja-jp" + + def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "ja_jp_phonemizer" + + def _phonemize(self, text: str, separator: str = "|") -> str: + ph = japanese_text_to_phonemes(text) + if separator is not None or separator != "": + return separator.join(ph) + return ph + + def phonemize(self, text: str, separator="|") -> str: + """Custom phonemize for JP_JA + + Skip pre-post processing steps used by the other phonemizers. + """ + return self._phonemize(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"ja-jp": "Japanese (Japan)"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +# if __name__ == "__main__": +# text = "これは、電話をかけるための私の日本語の例のテキストです。" +# e = JA_JP_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py new file mode 100644 index 00000000..e36b0a2a --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -0,0 +1,55 @@ +from typing import Dict, List + +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name + + +class MultiPhonemizer: + """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages + + Args: + custom_lang_to_phonemizer (Dict): + Custom phonemizer mapping if you want to change the defaults. In the format of + `{"lang_code", "phonemizer_name"}`. When it is None, `DEF_LANG_TO_PHONEMIZER` is used. Defaults to `{}`. 
+ + TODO: find a way to pass custom kwargs to the phonemizers + """ + + lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER + language = "multi-lingual" + + def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: # pylint: disable=dangerous-default-value + self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer) + self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name) + + @staticmethod + def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict: + lang_to_phonemizer = {} + for k, v in lang_to_phonemizer_name.items(): + phonemizer = get_phonemizer_by_name(v, language=k) + lang_to_phonemizer[k] = phonemizer + return lang_to_phonemizer + + @staticmethod + def name(): + return "multi-phonemizer" + + def phonemize(self, text, language, separator="|"): + return self.lang_to_phonemizer[language].phonemize(text, separator) + + def supported_languages(self) -> List: + return list(self.lang_to_phonemizer_name.keys()) + + +# if __name__ == "__main__": +# texts = { +# "tr": "Merhaba, bu Türkçe bit örnek!", +# "en-us": "Hello, this is English example!", +# "de": "Hallo, das ist ein Deutches Beipiel!", +# "zh-cn": "这是中国的例子", +# } +# phonemes = {} +# ph = MultiPhonemizer() +# for lang, text in texts.items(): +# phoneme = ph.phonemize(text, lang) +# phonemes[lang] = phoneme +# print(phonemes) diff --git a/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py new file mode 100644 index 00000000..5a4a5591 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py @@ -0,0 +1,62 @@ +from typing import Dict + +from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class ZH_CN_Phonemizer(BasePhonemizer): + """🐸TTS Zh-Cn phonemizer using functions in `TTS.tts.utils.text.chinese_mandarin.phonemizer` + + Args: + punctuations (str): + Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. 
+ + Example :: + + "这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| |,| |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |。` + + TODO: someone with Mandarin knowledge should check this implementation + """ + + language = "zh-cn" + + def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "zh_cn_phonemizer" + + @staticmethod + def phonemize_zh_cn(text: str, separator: str = "|") -> str: + ph = chinese_text_to_phonemes(text, separator) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_zh_cn(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"zh-cn": "Japanese (Japan)"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +# if __name__ == "__main__": +# text = "这是,样本中文。" +# e = ZH_CN_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py new file mode 100644 index 00000000..b2a058bb --- /dev/null +++ b/TTS/tts/utils/text/punctuation.py @@ -0,0 +1,172 @@ +import collections +import re +from enum import Enum + +import six + +_DEF_PUNCS = ';:,.!?¡¿—…"«»“”' + +_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"]) + + +class PuncPosition(Enum): + """Enum for the punctuations positions""" + + BEGIN = 0 + END = 1 + MIDDLE = 2 + ALONE = 3 + + +class Punctuation: + """Handle punctuations in text. + + Just strip punctuations from text or strip and restore them later. + + Args: + puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`. + + Example: + >>> punc = Punctuation() + >>> punc.strip("This is. example !") + 'This is example' + + >>> text_striped, punc_map = punc.strip_to_restore("This is. example !") + >>> ' '.join(text_striped) + 'This is example' + + >>> text_restored = punc.restore(text_striped, punc_map) + >>> text_restored[0] + 'This is. example !' + """ + + def __init__(self, puncs: str = _DEF_PUNCS): + self.puncs = puncs + + @staticmethod + def default_puncs(): + """Return default set of punctuations.""" + return _DEF_PUNCS + + @property + def puncs(self): + return self._puncs + + @puncs.setter + def puncs(self, value): + if not isinstance(value, six.string_types): + raise ValueError("[!] Punctuations must be of type str.") + self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder + self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+") + + def strip(self, text): + """Remove all the punctuations by replacing with `space`. + + Args: + text (str): The text to be processed. + + Example:: + + "This is. example !" -> "This is example " + """ + return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip() + + def strip_to_restore(self, text): + """Remove punctuations from text to restore them later. + + Args: + text (str): The text to be processed. + + Examples :: + + "This is. example !" 
-> [["This is", "example"], [".", "!"]] + + """ + text, puncs = self._strip_to_restore(text) + return text, puncs + + def _strip_to_restore(self, text): + """Auxiliary method for Punctuation.preserve()""" + matches = list(re.finditer(self.puncs_regular_exp, text)) + if not matches: + return [text], [] + # the text is only punctuations + if len(matches) == 1 and matches[0].group() == text: + return [], [_PUNC_IDX(text, PuncPosition.ALONE)] + # build a punctuation map to be used later to restore punctuations + puncs = [] + for match in matches: + position = PuncPosition.MIDDLE + if match == matches[0] and text.startswith(match.group()): + position = PuncPosition.BEGIN + elif match == matches[-1] and text.endswith(match.group()): + position = PuncPosition.END + puncs.append(_PUNC_IDX(match.group(), position)) + # convert str text to a List[str], each item is separated by a punctuation + splitted_text = [] + for idx, punc in enumerate(puncs): + split = text.split(punc.punc) + prefix, suffix = split[0], punc.punc.join(split[1:]) + splitted_text.append(prefix) + # if the text does not end with a punctuation, add it to the last item + if idx == len(puncs) - 1 and len(suffix) > 0: + splitted_text.append(suffix) + text = suffix + return splitted_text, puncs + + @classmethod + def restore(cls, text, puncs): + """Restore punctuation in a text. + + Args: + text (str): The text to be processed. + puncs (List[str]): The list of punctuations map to be used for restoring. + + Examples :: + + ['This is', 'example'], ['.', '!'] -> "This is. example!" + + """ + return cls._restore(text, puncs, 0) + + @classmethod + def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements + """Auxiliary method for Punctuation.restore()""" + if not puncs: + return text + + # nothing have been phonemized, returns the puncs alone + if not text: + return ["".join(m.mark for m in puncs)] + + current = puncs[0] + + if current.position == PuncPosition.BEGIN: + return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num) + + if current.position == PuncPosition.END: + return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1) + + if current.position == PuncPosition.ALONE: + return [current.mark] + cls._restore(text, puncs[1:], num + 1) + + # POSITION == MIDDLE + if len(text) == 1: # pragma: nocover + # a corner case where the final part of an intermediate + # mark (I) has not been phonemized + return cls._restore([text[0] + current.punc], puncs[1:], num) + + return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num) + + +# if __name__ == "__main__": +# punc = Punctuation() +# text = "This is. This is, example!" + +# print(punc.strip(text)) + +# split_text, puncs = punc.strip_to_restore(text) +# print(split_text, " ---- ", puncs) + +# restored_text = punc.restore(split_text, puncs) +# print(restored_text) diff --git a/TTS/tts/utils/text/symbols.py b/TTS/tts/utils/text/symbols.py deleted file mode 100644 index cb708958..00000000 --- a/TTS/tts/utils/text/symbols.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Defines the set of symbols used in text input to the model. - -The default is a set of ASCII characters that works well for English or text that has been run -through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. -""" - - -def make_symbols( - characters, - phonemes=None, - punctuations="!'(),-.:;? 
", - pad="_", - eos="~", - bos="^", - unique=True, -): # pylint: disable=redefined-outer-name - """Function to create symbols and phonemes - TODO: create phonemes_to_id and symbols_to_id dicts here.""" - _symbols = list(characters) - _symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols - _symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols - _symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols - _phonemes = None - if phonemes is not None: - _phonemes_sorted = ( - sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) - ) # this is to keep previous models compatible. - # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): - # _arpabet = ["@" + s for s in _phonemes_sorted] - # Export all symbols: - _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) - # _symbols += _arpabet - return _symbols, _phonemes - - -_pad = "_" -_eos = "~" -_bos = "^" -_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? " -_punctuations = "!'(),-.:;? " - -# Phonemes definition (All IPA characters) -_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" -_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" -_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" -_suprasegmentals = "ˈˌːˑ" -_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" -_diacrilics = "ɚ˞ɫ" -_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics - -symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos) - -# Generate ALIEN language -# from random import shuffle -# shuffle(phonemes) - - -def parse_symbols(): - return { - "pad": _pad, - "eos": _eos, - "bos": _bos, - "characters": _characters, - "punctuations": _punctuations, - "phonemes": _phonemes, - } - - -if __name__ == "__main__": - print(" > TTS symbols {}".format(len(symbols))) - print(symbols) - print(" > TTS phonemes {}".format(len(phonemes))) - print("".join(sorted(phonemes))) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py new file mode 100644 index 00000000..f0d85a44 --- /dev/null +++ b/TTS/tts/utils/text/tokenizer.py @@ -0,0 +1,205 @@ +from typing import Callable, Dict, List, Union + +from TTS.tts.utils.text import cleaners +from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +from TTS.utils.generic_utils import get_import_path, import_class + + +class TTSTokenizer: + """🐸TTS tokenizer to convert input characters to token IDs and back. + + Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later. + + Args: + use_phonemes (bool): + Whether to use phonemes instead of characters. Defaults to False. + + characters (Characters): + A Characters object to use for character-to-ID and ID-to-character mappings. + + text_cleaner (callable): + A function to pre-process the text before tokenization and phonemization. Defaults to None. + + phonemizer (Phonemizer): + A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None. + + Example: + + >>> from TTS.tts.utils.text.tokenizer import TTSTokenizer + >>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) + >>> text = "Hello world!" 
+ >>> ids = tokenizer.text_to_ids(text) + >>> text_hat = tokenizer.ids_to_text(ids) + >>> assert text == text_hat + """ + + def __init__( + self, + use_phonemes=False, + text_cleaner: Callable = None, + characters: "BaseCharacters" = None, + phonemizer: Union["Phonemizer", Dict] = None, + add_blank: bool = False, + use_eos_bos=False, + ): + self.text_cleaner = text_cleaner + self.use_phonemes = use_phonemes + self.add_blank = add_blank + self.use_eos_bos = use_eos_bos + self.characters = characters + self.not_found_characters = [] + self.phonemizer = phonemizer + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, new_characters): + self._characters = new_characters + self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None + self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None + + def encode(self, text: str) -> List[int]: + """Encodes a string of text as a sequence of IDs.""" + token_ids = [] + for char in text: + try: + idx = self.characters.char_to_id(char) + token_ids.append(idx) + except KeyError: + # discard but store not found characters + if char not in self.not_found_characters: + self.not_found_characters.append(char) + print(text) + print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") + return token_ids + + def decode(self, token_ids: List[int]) -> str: + """Decodes a sequence of IDs to a string of text.""" + text = "" + for token_id in token_ids: + text += self.characters.id_to_char(token_id) + return text + + def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument + """Converts a string of text to a sequence of token IDs. + + Args: + text(str): + The text to convert to token IDs. + + language(str): + The language code of the text. Defaults to None. + + TODO: + - Add support for language-specific processing. + + 1. Text normalizatin + 2. Phonemization (if use_phonemes is True) + 3. Add blank char between characters + 4. Add BOS and EOS characters + 5. Text to token IDs + """ + # TODO: text cleaner should pick the right routine based on the language + if self.text_cleaner is not None: + text = self.text_cleaner(text) + if self.use_phonemes: + text = self.phonemizer.phonemize(text, separator="") + if self.add_blank: + text = self.intersperse_blank_char(text, True) + if self.use_eos_bos: + text = self.pad_with_bos_eos(text) + return self.encode(text) + + def ids_to_text(self, id_sequence: List[int]) -> str: + """Converts a sequence of token IDs to a string of text.""" + return self.decode(id_sequence) + + def pad_with_bos_eos(self, char_sequence: List[str]): + """Pads a sequence with the special BOS and EOS characters.""" + return [self.characters.bos] + list(char_sequence) + [self.characters.eos] + + def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False): + """Intersperses the blank character between characters in a sequence. + + Use the ```blank``` character if defined else use the ```pad``` character. 
+ """ + char_to_use = self.characters.blank if use_blank_char else self.characters.pad + result = [char_to_use] * (len(char_sequence) * 2 + 1) + result[1::2] = char_sequence + return result + + def print_logs(self, level: int = 0): + indent = "\t" * level + print(f"{indent}| > add_blank: {self.add_blank}") + print(f"{indent}| > use_eos_bos: {self.use_eos_bos}") + print(f"{indent}| > use_phonemes: {self.use_phonemes}") + if self.use_phonemes: + print(f"{indent}| > phonemizer:") + self.phonemizer.print_logs(level + 1) + if len(self.not_found_characters) > 0: + print(f"{indent}| > {len(self.not_found_characters)} not found characters:") + for char in self.not_found_characters: + print(f"{indent}| > {char}") + + @staticmethod + def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): + """Init Tokenizer object from config + + Args: + config (Coqpit): Coqpit model config. + characters (BaseCharacters): Defines the model character set. If not set, use the default options based on + the config values. Defaults to None. + """ + # init cleaners + text_cleaner = None + if isinstance(config.text_cleaner, (str, list)): + text_cleaner = getattr(cleaners, config.text_cleaner) + + # init characters + if characters is None: + # set characters based on defined characters class + if config.characters and config.characters.characters_class: + CharactersClass = import_class(config.characters.characters_class) + characters, new_config = CharactersClass.init_from_config(config) + # set characters based on config + else: + if config.use_phonemes: + # init phoneme set + characters, new_config = IPAPhonemes().init_from_config(config) + else: + # init character set + characters, new_config = Graphemes().init_from_config(config) + + else: + characters, new_config = characters.init_from_config(config) + + # set characters class + new_config.characters.characters_class = get_import_path(characters) + + # init phonemizer + phonemizer = None + if config.use_phonemes: + phonemizer_kwargs = {"language": config.phoneme_language} + + if "phonemizer" in config and config.phonemizer: + phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs) + else: + try: + phonemizer = get_phonemizer_by_name( + DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs + ) + except KeyError as e: + raise ValueError( + f"""No phonemizer found for language {config.phoneme_language}. + You may need to install a third party library for this language.""" + ) from e + + return ( + TTSTokenizer( + config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars + ), + new_config, + ) diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index ff71958e..78c12981 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -4,8 +4,6 @@ import matplotlib.pyplot as plt import numpy as np import torch -from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme - matplotlib.use("Agg") @@ -89,12 +87,46 @@ def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False) return fig +def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the input characters. + + Args: + pitch (np.array): Pitch values. + chars (str): Characters to place to the x-axis. 
+ + Shapes: + pitch: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = chars + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + def visualize( alignment, postnet_output, text, hop_length, CONFIG, + tokenizer, stop_tokens=None, decoder_output=None, output_path=None, @@ -117,14 +149,8 @@ def visualize( plt.ylabel("Encoder timestamp", fontsize=label_fontsize) # compute phoneme representation and back if CONFIG.use_phonemes: - seq = phoneme_to_sequence( - text, - [CONFIG.text_cleaner], - CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if "characters" in CONFIG.keys() else None, - ) - text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None) + seq = tokenizer.text_to_ids(text) + text = tokenizer.ids_to_text(seq) print(text) plt.yticks(range(len(text)), list(text)) plt.colorbar() diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 25f93c34..d0777c11 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -142,10 +142,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method ) M = o[:, :, :, 0] P = o[:, :, :, 1] - S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) + S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) if self.power is not None: - S = S ** self.power + S = S**self.power if self.use_mel: S = torch.matmul(self.mel_basis.to(x), S) @@ -239,6 +239,12 @@ class AudioProcessor(object): mel_fmax (int, optional): maximum filter frequency for computing melspectrograms. Defaults to None. + pitch_fmin (int, optional): + minimum filter frequency for computing pitch. Defaults to None. + + pitch_fmax (int, optional): + maximum filter frequency for computing pitch. Defaults to None. + spec_gain (int, optional): gain applied when converting amplitude to DB. Defaults to 20. 
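# Illustrative usage sketch added by the editor; it is not part of the patch itself.
# It shows how the new `pitch_fmin`/`pitch_fmax` AudioProcessor arguments documented
# in the hunk above would be used, following the `compute_f0` docstring example that
# appears further below in this diff. The wav path is a placeholder and the 8000 Hz
# ceiling is taken verbatim from that docstring example.
from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor

conf = BaseAudioConfig(pitch_fmax=8000)        # F0 search ceiling in Hz (value from the docstring example)
ap = AudioProcessor(**conf)
wav = ap.load_wav("example.wav", sr=22050)     # placeholder file path
pitch = ap.compute_f0(wav)                     # asserts pitch_fmax is set, then uses it as f0_ceil
print(pitch.shape)                             # roughly one F0 value per spectrogram frame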
@@ -300,6 +306,8 @@ class AudioProcessor(object): max_norm=None, mel_fmin=None, mel_fmax=None, + pitch_fmax=None, + pitch_fmin=None, spec_gain=20, stft_pad_mode="reflect", clip_norm=True, @@ -333,6 +341,8 @@ class AudioProcessor(object): self.symmetric_norm = symmetric_norm self.mel_fmin = mel_fmin or 0 self.mel_fmax = mel_fmax + self.pitch_fmin = pitch_fmin + self.pitch_fmax = pitch_fmax self.spec_gain = float(spec_gain) self.stft_pad_mode = stft_pad_mode self.max_norm = 1.0 if max_norm is None else float(max_norm) @@ -379,6 +389,12 @@ class AudioProcessor(object): self.clip_norm = None self.symmetric_norm = None + @staticmethod + def init_from_config(config: "Coqpit", verbose=True): + if "audio" in config: + return AudioProcessor(verbose=verbose, **config.audio) + return AudioProcessor(verbose=verbose, **config) + ### setting up the parameters ### def _build_mel_basis( self, @@ -634,8 +650,8 @@ class AudioProcessor(object): S = self._db_to_amp(S) # Reconstruct phase if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) - return self._griffin_lim(S ** self.power) + return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) + return self._griffin_lim(S**self.power) def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" @@ -643,8 +659,8 @@ class AudioProcessor(object): S = self._db_to_amp(D) S = self._mel_to_linear(S) # Convert back to linear if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) - return self._griffin_lim(S ** self.power) + return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) + return self._griffin_lim(S**self.power) def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: """Convert a full scale linear spectrogram output of a network to a melspectrogram. @@ -720,11 +736,12 @@ class AudioProcessor(object): >>> WAV_FILE = filename = librosa.util.example_audio_file() >>> from TTS.config import BaseAudioConfig >>> from TTS.utils.audio import AudioProcessor - >>> conf = BaseAudioConfig(mel_fmax=8000) + >>> conf = BaseAudioConfig(pitch_fmax=8000) >>> ap = AudioProcessor(**conf) >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] >>> pitch = ap.compute_f0(wav) """ + assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`." # align F0 length to the spectrogram length if len(x) % self.hop_length == 0: x = np.pad(x, (0, self.hop_length // 2), mode="reflect") @@ -732,7 +749,7 @@ class AudioProcessor(object): f0, t = pw.dio( x.astype(np.double), fs=self.sample_rate, - f0_ceil=self.mel_fmax, + f0_ceil=self.pitch_fmax, frame_period=1000 * self.hop_length / self.sample_rate, ) f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) @@ -781,7 +798,7 @@ class AudioProcessor(object): @staticmethod def _rms_norm(wav, db_level=-27): r = 10 ** (db_level / 20) - a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2)) + a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) return wav * a def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: @@ -853,7 +870,7 @@ class AudioProcessor(object): @staticmethod def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray: - mu = 2 ** qc - 1 + mu = 2**qc - 1 # wav_abs = np.minimum(np.abs(wav), 1.0) signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) # Quantize signal to the specified number of levels. 
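# Illustrative round-trip sketch added by the editor; it is not part of the patch.
# The mulaw_encode/mulaw_decode hunks around this point only reformat the exponent
# spacing (`2 ** qc` -> `2**qc`); behaviour is unchanged. The snippet below reproduces
# just the mu-law companding/expansion formulas shown in those methods (the final
# quantization step of `mulaw_encode` is omitted) to show that expansion inverts
# compression.
import numpy as np

def mulaw_compress(wav: np.ndarray, qc: int) -> np.ndarray:
    mu = 2**qc - 1
    return np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)

def mulaw_expand(wav: np.ndarray, qc: int) -> np.ndarray:
    mu = 2**qc - 1
    return np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)

x = np.linspace(-0.9, 0.9, 16)
assert np.allclose(mulaw_expand(mulaw_compress(x, qc=10), qc=10), x)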
@@ -865,13 +882,13 @@ class AudioProcessor(object): @staticmethod def mulaw_decode(wav, qc): """Recovers waveform from quantized values.""" - mu = 2 ** qc - 1 + mu = 2**qc - 1 x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) return x @staticmethod def encode_16bits(x): - return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) @staticmethod def quantize(x: np.ndarray, bits: int) -> np.ndarray: @@ -884,12 +901,12 @@ class AudioProcessor(object): Returns: np.ndarray: Quantized waveform. """ - return (x + 1.0) * (2 ** bits - 1) / 2 + return (x + 1.0) * (2**bits - 1) / 2 @staticmethod def dequantize(x, bits): """Dequantize a waveform from the given number of bits.""" - return 2 * x / (2 ** bits - 1) - 1 + return 2 * x / (2**bits - 1) - 1 def _log(x, base): diff --git a/TTS/utils/download.py b/TTS/utils/download.py index 241a106b..de9b31a7 100644 --- a/TTS/utils/download.py +++ b/TTS/utils/download.py @@ -128,7 +128,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> while True: # Read by chunk to avoid filling memory - chunk = file_obj.read(1024 ** 2) + chunk = file_obj.read(1024**2) if not chunk: break hash_func.update(chunk) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 6504cca6..69609bcb 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -95,6 +95,33 @@ def find_module(module_path: str, module_name: str) -> object: return getattr(module, class_name) +def import_class(module_path: str) -> object: + """Import a class from a module path. + + Args: + module_path (str): The module path of the class. + + Returns: + object: The imported class. + """ + class_name = module_path.split(".")[-1] + module_path = ".".join(module_path.split(".")[:-1]) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_import_path(obj: object) -> str: + """Get the import path of a class. + + Args: + obj (object): The class object. + + Returns: + str: The import path of the class. + """ + return ".".join([type(obj).__module__, type(obj).__name__]) + + def get_user_data_dir(appname): if sys.platform == "win32": import winreg # pylint: disable=import-outside-toplevel diff --git a/TTS/utils/logging/__init__.py b/TTS/utils/logging/__init__.py deleted file mode 100644 index 43fbf6f1..00000000 --- a/TTS/utils/logging/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from TTS.utils.logging.console_logger import ConsoleLogger -from TTS.utils.logging.tensorboard_logger import TensorboardLogger -from TTS.utils.logging.wandb_logger import WandbLogger - - -def init_dashboard_logger(config): - if config.dashboard_logger == "tensorboard": - dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model) - - elif config.dashboard_logger == "wandb": - project_name = config.model - if config.project_name: - project_name = config.project_name - - dashboard_logger = WandbLogger( - project=project_name, - name=config.run_name, - config=config, - entity=config.wandb_entity, - ) - - dashboard_logger.add_text("model-config", f"
<pre>{config.to_json()}</pre>
", 0) - - return dashboard_logger diff --git a/TTS/utils/logging/console_logger.py b/TTS/utils/logging/console_logger.py deleted file mode 100644 index 74371342..00000000 --- a/TTS/utils/logging/console_logger.py +++ /dev/null @@ -1,105 +0,0 @@ -import datetime - -from TTS.utils.io import AttrDict - -tcolors = AttrDict( - { - "OKBLUE": "\033[94m", - "HEADER": "\033[95m", - "OKGREEN": "\033[92m", - "WARNING": "\033[93m", - "FAIL": "\033[91m", - "ENDC": "\033[0m", - "BOLD": "\033[1m", - "UNDERLINE": "\033[4m", - } -) - - -class ConsoleLogger: - def __init__(self): - # TODO: color code for value changes - # use these to compare values between iterations - self.old_train_loss_dict = None - self.old_epoch_loss_dict = None - self.old_eval_loss_dict = None - - # pylint: disable=no-self-use - def get_time(self): - now = datetime.datetime.now() - return now.strftime("%Y-%m-%d %H:%M:%S") - - def print_epoch_start(self, epoch, max_epoch, output_path=None): - print( - "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), - flush=True, - ) - if output_path is not None: - print(f" --> {output_path}") - - def print_train_start(self): - print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") - - def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_dict): - indent = " | > " - print() - log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format( - tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC - ) - for key, value in loss_dict.items(): - if f"avg_{key}" in avg_loss_dict.keys(): - # print the avg value if given - if isinstance(value, float) and round(value, 5) == 0: - # do not round the number if it is zero when rounded - log_text += "{}{}: {} ({})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) - else: - # print the rounded value - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) - else: - if isinstance(value, float) and round(value, 5) == 0: - log_text += "{}{}: {} \n".format(indent, key, value) - else: - log_text += "{}{}: {:.5f} \n".format(indent, key, value) - print(log_text, flush=True) - - # pylint: disable=unused-argument - def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict): - indent = " | > " - log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n" - for key, value in print_dict.items(): - log_text += "{}{}: {:.5f}\n".format(indent, key, value) - print(log_text, flush=True) - - def print_eval_start(self): - print(f"\n{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n") - - def print_eval_step(self, step, loss_dict, avg_loss_dict): - indent = " | > " - log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n" - for key, value in loss_dict.items(): - # print the avg value if given - if f"avg_{key}" in avg_loss_dict.keys(): - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) - else: - log_text += "{}{}: {:.5f} \n".format(indent, key, value) - print(log_text, flush=True) - - def print_epoch_end(self, epoch, avg_loss_dict): - indent = " | > " - log_text = "\n {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC) - for key, value in avg_loss_dict.items(): - # print the avg value if given - color = "" - sign = "+" - diff = 0 - if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict: - diff = value - self.old_eval_loss_dict[key] - if diff < 0: - color = tcolors.OKGREEN - sign = "" 
- elif diff > 0: - color = tcolors.FAIL - sign = "+" - log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff) - self.old_eval_loss_dict = avg_loss_dict - print(log_text, flush=True) diff --git a/TTS/utils/logging/tensorboard_logger.py b/TTS/utils/logging/tensorboard_logger.py deleted file mode 100644 index 812683f7..00000000 --- a/TTS/utils/logging/tensorboard_logger.py +++ /dev/null @@ -1,79 +0,0 @@ -import traceback - -from tensorboardX import SummaryWriter - - -class TensorboardLogger(object): - def __init__(self, log_dir, model_name): - self.model_name = model_name - self.writer = SummaryWriter(log_dir) - - def model_weights(self, model, step): - layer_num = 1 - for name, param in model.named_parameters(): - if param.numel() == 1: - self.writer.add_scalar("layer{}-{}/value".format(layer_num, name), param.max(), step) - else: - self.writer.add_scalar("layer{}-{}/max".format(layer_num, name), param.max(), step) - self.writer.add_scalar("layer{}-{}/min".format(layer_num, name), param.min(), step) - self.writer.add_scalar("layer{}-{}/mean".format(layer_num, name), param.mean(), step) - self.writer.add_scalar("layer{}-{}/std".format(layer_num, name), param.std(), step) - self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) - self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) - layer_num += 1 - - def dict_to_tb_scalar(self, scope_name, stats, step): - for key, value in stats.items(): - self.writer.add_scalar("{}/{}".format(scope_name, key), value, step) - - def dict_to_tb_figure(self, scope_name, figures, step): - for key, value in figures.items(): - self.writer.add_figure("{}/{}".format(scope_name, key), value, step) - - def dict_to_tb_audios(self, scope_name, audios, step, sample_rate): - for key, value in audios.items(): - if value.dtype == "float16": - value = value.astype("float32") - try: - self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate) - except RuntimeError: - traceback.print_exc() - - def train_step_stats(self, step, stats): - self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step) - - def train_epoch_stats(self, step, stats): - self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step) - - def train_figures(self, step, figures): - self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step) - - def train_audios(self, step, audios, sample_rate): - self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate) - - def eval_stats(self, step, stats): - self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step) - - def eval_figures(self, step, figures): - self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step) - - def eval_audios(self, step, audios, sample_rate): - self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate) - - def test_audios(self, step, audios, sample_rate): - self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate) - - def test_figures(self, step, figures): - self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step) - - def add_text(self, title, text, step): - self.writer.add_text(title, text, step) - - def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613, R0201 - yield - - def flush(self): - self.writer.flush() - - def finish(self): - self.writer.close() diff --git a/TTS/utils/logging/wandb_logger.py 
b/TTS/utils/logging/wandb_logger.py deleted file mode 100644 index 5fcab00f..00000000 --- a/TTS/utils/logging/wandb_logger.py +++ /dev/null @@ -1,111 +0,0 @@ -# pylint: disable=W0613 - -import traceback -from pathlib import Path - -try: - import wandb - from wandb import finish, init # pylint: disable=W0611 -except ImportError: - wandb = None - - -class WandbLogger: - def __init__(self, **kwargs): - - if not wandb: - raise Exception("install wandb using `pip install wandb` to use WandbLogger") - - self.run = None - self.run = wandb.init(**kwargs) if not wandb.run else wandb.run - self.model_name = self.run.config.model - self.log_dict = {} - - def model_weights(self, model): - layer_num = 1 - for name, param in model.named_parameters(): - if param.numel() == 1: - self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()}) - else: - self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()}) - self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()}) - self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()}) - self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()}) - self.log_dict["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param) - self.log_dict["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad) - layer_num += 1 - - def dict_to_scalar(self, scope_name, stats): - for key, value in stats.items(): - self.log_dict["{}/{}".format(scope_name, key)] = value - - def dict_to_figure(self, scope_name, figures): - for key, value in figures.items(): - self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value) - - def dict_to_audios(self, scope_name, audios, sample_rate): - for key, value in audios.items(): - if value.dtype == "float16": - value = value.astype("float32") - try: - self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate) - except RuntimeError: - traceback.print_exc() - - def log(self, log_dict, prefix="", flush=False): - for key, value in log_dict.items(): - self.log_dict[prefix + key] = value - if flush: # for cases where you don't want to accumulate data - self.flush() - - def train_step_stats(self, step, stats): - self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats) - - def train_epoch_stats(self, step, stats): - self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats) - - def train_figures(self, step, figures): - self.dict_to_figure(f"{self.model_name}_TrainFigures", figures) - - def train_audios(self, step, audios, sample_rate): - self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate) - - def eval_stats(self, step, stats): - self.dict_to_scalar(f"{self.model_name}_EvalStats", stats) - - def eval_figures(self, step, figures): - self.dict_to_figure(f"{self.model_name}_EvalFigures", figures) - - def eval_audios(self, step, audios, sample_rate): - self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate) - - def test_audios(self, step, audios, sample_rate): - self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate) - - def test_figures(self, step, figures): - self.dict_to_figure(f"{self.model_name}_TestFigures", figures) - - def add_text(self, title, text, step): - pass - - def flush(self): - if self.run: - wandb.log(self.log_dict) - self.log_dict = {} - - def finish(self): - if self.run: - self.run.finish() - - def log_artifact(self, file_or_dir, name, artifact_type, 
aliases=None): - if not self.run: - return - name = "_".join([self.run.id, name]) - artifact = wandb.Artifact(name, type=artifact_type) - data_path = Path(file_or_dir) - if data_path.is_dir(): - artifact.add_dir(str(data_path)) - elif data_path.is_file(): - artifact.add_file(str(data_path)) - - self.run.log_artifact(artifact, aliases=aliases) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index fc45e7fa..d1abc907 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -5,10 +5,8 @@ import numpy as np import pysbd import torch -from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config +from TTS.config import load_config from TTS.tts.models import setup_model as setup_tts_model -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import @@ -110,25 +108,12 @@ class Synthesizer(object): use_cuda (bool): enable/disable CUDA use. """ # pylint: disable=global-statement - self.tts_config = load_config(tts_config_path) self.use_phonemes = self.tts_config.use_phonemes - self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) + self.tts_model = setup_tts_model(config=self.tts_config) - speaker_manager = self._init_speaker_manager() - language_manager = self._init_language_manager() if not self.encoder_checkpoint: self._set_speaker_encoder_paths_from_tts_config() - speaker_manager = self._init_speaker_encoder(speaker_manager) - - if language_manager is not None: - self.tts_model = setup_tts_model( - config=self.tts_config, - speaker_manager=speaker_manager, - language_manager=language_manager, - ) - else: - self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager) self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) if use_cuda: self.tts_model.cuda() @@ -141,69 +126,6 @@ class Synthesizer(object): self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path - def _is_use_speaker_embedding(self): - """Check if the speaker embedding is used in the model""" - # we handle here the case that some models use model_args some don't - use_speaker_embedding = False - if hasattr(self.tts_config, "model_args"): - use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False) - use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False) - return use_speaker_embedding - - def _is_use_d_vector_file(self): - """Check if the d-vector file is used in the model""" - # we handle here the case that some models use model_args some don't - use_d_vector_file = False - if hasattr(self.tts_config, "model_args"): - config = self.tts_config.model_args - use_d_vector_file = config.get("use_d_vector_file", False) - config = self.tts_config - use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False) - return use_d_vector_file - - def _init_speaker_manager(self): - """Initialize the SpeakerManager""" - # setup if multi-speaker settings are in the global model config - speaker_manager = None - speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None) - if self._is_use_speaker_embedding(): - if self.tts_speakers_file: - speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file) - elif speakers_file: - speaker_manager = 
SpeakerManager(speaker_id_file_path=speakers_file) - - if self._is_use_d_vector_file(): - d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None) - if self.tts_speakers_file: - speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file) - elif d_vector_file: - speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file) - return speaker_manager - - def _init_speaker_encoder(self, speaker_manager): - """Initialize the SpeakerEncoder""" - if self.encoder_checkpoint: - if speaker_manager is None: - speaker_manager = SpeakerManager( - encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config - ) - else: - speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config) - return speaker_manager - - def _init_language_manager(self): - """Initialize the LanguageManager""" - # setup if multi-lingual settings are in the global model config - language_manager = None - if check_config_and_model_args(self.tts_config, "use_language_embedding", True): - if self.tts_languages_file: - language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file) - elif self.tts_config.get("language_ids_file", None): - language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file) - else: - language_manager = LanguageManager(config=self.tts_config) - return language_manager - def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: """Load the vocoder model. @@ -243,7 +165,7 @@ class Synthesizer(object): path (str): output path to save the waveform. """ wav = np.array(wav) - self.ap.save_wav(wav, path, self.output_sample_rate) + self.tts_model.ap.save_wav(wav, path, self.output_sample_rate) def tts( self, @@ -331,12 +253,9 @@ class Synthesizer(object): text=sen, CONFIG=self.tts_config, use_cuda=self.use_cuda, - ap=self.ap, speaker_id=speaker_id, language_id=language_id, - language_name=language_name, style_wav=style_wav, - enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars, use_griffin_lim=use_gl, d_vector=speaker_embedding, ) @@ -344,14 +263,14 @@ class Synthesizer(object): mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() if not use_gl: # denormalize tts output based on tts audio config - mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch scale_factor = [ 1, - self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: print(" > interpolating tts model output.") @@ -369,7 +288,7 @@ class Synthesizer(object): # trim silence if self.tts_config.audio["do_trim_silence"] is True: - waveform = trim_silence(waveform, self.ap) + waveform = trim_silence(waveform, self.tts_model.ap) wavs += list(waveform) wavs += [0] * 10000 diff --git a/TTS/utils/trainer_utils.py b/TTS/utils/trainer_utils.py deleted file mode 100644 index dabb33cd..00000000 --- a/TTS/utils/trainer_utils.py +++ /dev/null @@ -1,150 +0,0 @@ -import importlib -import os -import re -from typing import Dict, List, Tuple -from urllib.parse import urlparse - -import fsspec -import torch - -from TTS.utils.io import load_fsspec -from 
TTS.utils.training import NoamLR - - -def is_apex_available(): - return importlib.util.find_spec("apex") is not None - - -def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]: - """Setup PyTorch environment for training. - - Args: - cudnn_enable (bool): Enable/disable CUDNN. - cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better to set to False if input sequence length is - variable between batches. - use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise. - - Returns: - Tuple[bool, int]: is cuda on or off and number of GPUs in the environment. - """ - num_gpus = torch.cuda.device_count() - if num_gpus > 1 and not use_ddp: - raise RuntimeError( - f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`." - ) - torch.backends.cudnn.enabled = cudnn_enable - torch.backends.cudnn.benchmark = cudnn_benchmark - torch.manual_seed(54321) - use_cuda = torch.cuda.is_available() - print(" > Using CUDA: ", use_cuda) - print(" > Number of GPUs: ", num_gpus) - return use_cuda, num_gpus - - -def get_scheduler( - lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer -) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access - """Find, initialize and return a scheduler. - - Args: - lr_scheduler (str): Scheduler name. - lr_scheduler_params (Dict): Scheduler parameters. - optimizer (torch.optim.Optimizer): Optimizer to pass to the scheduler. - - Returns: - torch.optim.lr_scheduler._LRScheduler: Functional scheduler. - """ - if lr_scheduler is None: - return None - if lr_scheduler.lower() == "noamlr": - scheduler = NoamLR - else: - scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler) - return scheduler(optimizer, **lr_scheduler_params) - - -def get_optimizer( - optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None -) -> torch.optim.Optimizer: - """Find, initialize and return a optimizer. - - Args: - optimizer_name (str): Optimizer name. - optimizer_params (dict): Optimizer parameters. - lr (float): Initial learning rate. - model (torch.nn.Module): Model to pass to the optimizer. - - Returns: - torch.optim.Optimizer: Functional optimizer. - """ - if optimizer_name.lower() == "radam": - module = importlib.import_module("TTS.utils.radam") - optimizer = getattr(module, "RAdam") - else: - optimizer = getattr(torch.optim, optimizer_name) - if model is not None: - parameters = model.parameters() - return optimizer(parameters, lr=lr, **optimizer_params) - - -def get_last_checkpoint(path: str) -> Tuple[str, str]: - """Get latest checkpoint or/and best model in path. - - It is based on globbing for `*.pth.tar` and the RegEx - `(checkpoint|best_model)_([0-9]+)`. - - Args: - path: Path to files to be compared. - - Raises: - ValueError: If no checkpoint or best_model files are found. - - Returns: - Path to the last checkpoint - Path to best checkpoint - """ - fs = fsspec.get_mapper(path).fs - file_names = fs.glob(os.path.join(path, "*.pth.tar")) - scheme = urlparse(path).scheme - if scheme: # scheme is not preserved in fs.glob, add it back - file_names = [scheme + "://" + file_name for file_name in file_names] - last_models = {} - last_model_nums = {} - for key in ["checkpoint", "best_model"]: - last_model_num = None - last_model = None - # pass all the checkpoint files and find - # the one with the largest model number suffix. 
- for file_name in file_names: - match = re.search(f"{key}_([0-9]+)", file_name) - if match is not None: - model_num = int(match.groups()[0]) - if last_model_num is None or model_num > last_model_num: - last_model_num = model_num - last_model = file_name - - # if there is no checkpoint found above - # find the checkpoint with the latest - # modification date. - key_file_names = [fn for fn in file_names if key in fn] - if last_model is None and len(key_file_names) > 0: - last_model = max(key_file_names, key=os.path.getctime) - last_model_num = load_fsspec(last_model)["step"] - - if last_model is not None: - last_models[key] = last_model - last_model_nums[key] = last_model_num - - # check what models were found - if not last_models: - raise ValueError(f"No models found in continue path {path}!") - if "checkpoint" not in last_models: # no checkpoint just best model - last_models["checkpoint"] = last_models["best_model"] - elif "best_model" not in last_models: # no best model - # this shouldn't happen, but let's handle it just in case - last_models["best_model"] = last_models["checkpoint"] - # finally check if last best model is more recent than checkpoint - elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: - last_models["checkpoint"] = last_models["best_model"] - - return last_models["checkpoint"], last_models["best_model"] diff --git a/TTS/utils/training.py b/TTS/utils/training.py index aa5651c5..b51f55e9 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -30,20 +30,6 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): return grad_norm, skip_flag -# pylint: disable=protected-access -class NoamLR(torch.optim.lr_scheduler._LRScheduler): - def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): - self.warmup_steps = float(warmup_steps) - super().__init__(optimizer, last_epoch) - - def get_lr(self): - step = max(self.last_epoch, 1) - return [ - base_lr * self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5) - for base_lr in self.base_lrs - ] - - def gradual_training_scheduler(global_step, config): """Setup the gradual training schedule wrt number of active GPUs""" @@ -56,31 +42,3 @@ def gradual_training_scheduler(global_step, config): if global_step * num_gpus >= values[0]: new_values = values return new_values[1], new_values[2] - - -def lr_decay(init_lr, global_step, warmup_steps): - r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py - It is only being used by the Speaker Encoder trainer.""" - warmup_steps = float(warmup_steps) - step = global_step + 1.0 - lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5) - return lr - - -# pylint: disable=dangerous-default-value -def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}): - """ - Skip biases, BatchNorm parameters, rnns. 
- and attention projection layer v - """ - decay = [] - no_decay = [] - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - - if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)): - no_decay.append(param) - else: - decay.append(param) - return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}] diff --git a/TTS/vocoder/configs/parallel_wavegan_config.py b/TTS/vocoder/configs/parallel_wavegan_config.py index a89b1f3f..7845dd6b 100644 --- a/TTS/vocoder/configs/parallel_wavegan_config.py +++ b/TTS/vocoder/configs/parallel_wavegan_config.py @@ -70,11 +70,11 @@ class ParallelWaveganConfig(BaseGANVocoderConfig): lr_scheduler_gen (torch.optim.Scheduler): Learning rate scheduler for the generator. Defaults to `ExponentialLR`. lr_scheduler_gen_params (dict): - Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. lr_scheduler_disc (torch.optim.Scheduler): Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`. lr_scheduler_dict_params (dict): - Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. """ model: str = "parallel_wavegan" @@ -124,7 +124,10 @@ class ParallelWaveganConfig(BaseGANVocoderConfig): lr_disc: float = 0.0002 # Initial learning rate. optimizer: str = "AdamW" optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) - lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html - lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) - lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html - lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}) + lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1} + ) + scheduler_after_epoch: bool = False diff --git a/TTS/vocoder/configs/shared_configs.py b/TTS/vocoder/configs/shared_configs.py index 9ff6f790..a558cfca 100644 --- a/TTS/vocoder/configs/shared_configs.py +++ b/TTS/vocoder/configs/shared_configs.py @@ -99,7 +99,7 @@ class BaseGANVocoderConfig(BaseVocoderConfig): "mel_fmax": None, }` target_loss (str): - Target loss name that defines the quality of the model. Defaults to `avg_G_loss`. + Target loss name that defines the quality of the model. Defaults to `G_avg_loss`. grad_clip (list): A list of gradient clipping theresholds for each optimizer. Any value less than 0 disables clipping. Defaults to [5, 5]. 
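For context, a minimal torch-only sketch (not part of the diff) of what the new `StepLR` defaults in `ParallelWaveganConfig` amount to; the single parameter below is a stand-in for the generator weights, and the per-iteration `scheduler.step()` call is assumed to be done by the trainer (`scheduler_after_epoch=False`):

```python
import torch

# Placeholder parameter standing in for the ParallelWaveGAN generator weights.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=0.0002, betas=(0.8, 0.99), weight_decay=0.0)

# New defaults: halve the learning rate every 200k steps instead of ExponentialLR decay.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200_000, gamma=0.5, last_epoch=-1)

print(scheduler.get_last_lr())  # [0.0002]; drops to 1e-4 after 200k steps, 5e-5 after 400k, ...
```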
diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py index d8cc350a..0f69b812 100644 --- a/TTS/vocoder/datasets/preprocess.py +++ b/TTS/vocoder/datasets/preprocess.py @@ -33,8 +33,8 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): np.save(quant_path, quant) -def find_wav_files(data_path): - wav_paths = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) +def find_wav_files(data_path, file_ext="wav"): + wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True) return wav_paths @@ -43,8 +43,9 @@ def find_feat_files(data_path): return feat_paths -def load_wav_data(data_path, eval_split_size): - wav_paths = find_wav_files(data_path) +def load_wav_data(data_path, eval_split_size, file_ext="wav"): + wav_paths = find_wav_files(data_path, file_ext=file_ext) + assert len(wav_paths) > 0, f" [!] {data_path} is empty." np.random.seed(0) np.random.shuffle(wav_paths) return wav_paths[:eval_split_size], wav_paths[eval_split_size:] diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index d648b68c..2c771cf0 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -111,7 +111,7 @@ class WaveRNNDataset(Dataset): elif isinstance(self.mode, int): coarse = np.stack(coarse).astype(np.int64) coarse = torch.LongTensor(coarse) - x_input = 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0 + x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0 y_coarse = coarse[:, 1:] mels = torch.FloatTensor(mels) return x_input, mels, y_coarse diff --git a/TTS/vocoder/layers/lvc_block.py b/TTS/vocoder/layers/lvc_block.py index 0e29ee3c..8913a113 100644 --- a/TTS/vocoder/layers/lvc_block.py +++ b/TTS/vocoder/layers/lvc_block.py @@ -126,9 +126,9 @@ class LVCBlock(torch.nn.Module): ) for i in range(conv_layers): - padding = (3 ** i) * int((conv_kernel_size - 1) / 2) + padding = (3**i) * int((conv_kernel_size - 1) / 2) conv = torch.nn.Conv1d( - in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i + in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i ) self.convs.append(conv) diff --git a/TTS/vocoder/layers/melgan.py b/TTS/vocoder/layers/melgan.py index 7fd999d9..4bb328e9 100644 --- a/TTS/vocoder/layers/melgan.py +++ b/TTS/vocoder/layers/melgan.py @@ -12,7 +12,7 @@ class ResidualStack(nn.Module): self.blocks = nn.ModuleList() for idx in range(num_res_blocks): layer_kernel_size = kernel_size - layer_dilation = layer_kernel_size ** idx + layer_dilation = layer_kernel_size**idx layer_padding = base_padding * layer_dilation self.blocks += [ nn.Sequential( diff --git a/TTS/vocoder/layers/parallel_wavegan.py b/TTS/vocoder/layers/parallel_wavegan.py index 889e8aa6..51142e5e 100644 --- a/TTS/vocoder/layers/parallel_wavegan.py +++ b/TTS/vocoder/layers/parallel_wavegan.py @@ -72,6 +72,6 @@ class ResidualBlock(torch.nn.Module): s = self.conv1x1_skip(x) # for residual connection - x = (self.conv1x1_out(x) + residual) * (0.5 ** 2) + x = (self.conv1x1_out(x) + residual) * (0.5**2) return x, s diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index a70ebe40..65901617 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -28,8 +28,7 @@ def setup_model(config: Coqpit): except ModuleNotFoundError as e: raise ValueError(f"Model {config.model} not exist!") from e print(" > Vocoder Model: 
{}".format(config.model)) - model = MyModel(config) - return model + return MyModel.init_from_config(config) def setup_generator(c): diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py index 9d6ef26f..01a7ff68 100644 --- a/TTS/vocoder/models/base_vocoder.py +++ b/TTS/vocoder/models/base_vocoder.py @@ -1,11 +1,11 @@ from coqpit import Coqpit -from TTS.model import BaseModel +from TTS.model import BaseTrainerModel # pylint: skip-file -class BaseVocoder(BaseModel): +class BaseVocoder(BaseTrainerModel): """Base `vocoder` class. Every new `vocoder` model must inherit this. It defines `vocoder` specific functions on top of `Model`. @@ -19,7 +19,8 @@ class BaseVocoder(BaseModel): """ def __init__(self, config): - super().__init__(config) + super().__init__() + self._set_model_args(config) def _set_model_args(self, config: Coqpit): """Setup model args based on the config type. diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 76fee505..3b8a3fbe 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -7,10 +7,10 @@ from coqpit import Coqpit from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec -from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.models import setup_discriminator, setup_generator @@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results class GAN(BaseVocoder): - def __init__(self, config: Coqpit): + def __init__(self, config: Coqpit, ap: AudioProcessor = None): """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. It also helps mixing and matching different generator and disciminator networks easily. @@ -28,6 +28,7 @@ class GAN(BaseVocoder): Args: config (Coqpit): Model configuration. + ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None. Examples: Initializing the GAN model with HifiGAN generator and discriminator. @@ -41,6 +42,7 @@ class GAN(BaseVocoder): self.model_d = setup_discriminator(config) self.train_disc = False # if False, train only the generator. self.y_hat_g = None # the last generator prediction to be passed onto the discriminator + self.ap = ap def forward(self, x: torch.Tensor) -> torch.Tensor: """Run the generator's forward pass. @@ -78,8 +80,8 @@ class GAN(BaseVocoder): Returns: Tuple[Dict, Dict]: model outputs and the computed loss values. 
""" - outputs = None - loss_dict = None + outputs = {} + loss_dict = {} x = batch["input"] y = batch["waveform"] @@ -201,10 +203,9 @@ class GAN(BaseVocoder): self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument ) -> Tuple[Dict, np.ndarray]: """Call `_log()` for training.""" - ap = assets["audio_processor"] - figures, audios = self._log("eval", ap, batch, outputs) + figures, audios = self._log("eval", self.ap, batch, outputs) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: @@ -215,10 +216,9 @@ class GAN(BaseVocoder): self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument ) -> Tuple[Dict, np.ndarray]: """Call `_log()` for evaluation.""" - ap = assets["audio_processor"] - figures, audios = self._log("eval", ap, batch, outputs) + figures, audios = self._log("eval", self.ap, batch, outputs) logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, ap.sample_rate) + logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( self, @@ -306,15 +306,15 @@ class GAN(BaseVocoder): x, y = batch return {"input": x, "waveform": y} - def get_data_loader( # pylint: disable=no-self-use + def get_data_loader( # pylint: disable=no-self-use, unused-argument self, config: Coqpit, assets: Dict, is_eval: True, - data_items: List, + samples: List, verbose: bool, num_gpus: int, - rank: int = 0, # pylint: disable=unused-argument + rank: int = None, # pylint: disable=unused-argument ): """Initiate and return the GAN dataloader. @@ -322,19 +322,19 @@ class GAN(BaseVocoder): config (Coqpit): Model config. ap (AudioProcessor): Audio processor. is_eval (True): Set the dataloader for evaluation if true. - data_items (List): Data samples. + samples (List): Data samples. verbose (bool): Log information if true. num_gpus (int): Number of GPUs in use. + rank (int): Rank of the current GPU. Defaults to None. Returns: DataLoader: Torch dataloader. 
""" - ap = assets["audio_processor"] dataset = GANDataset( - ap=ap, - items=data_items, + ap=self.ap, + items=samples, seq_len=config.seq_len, - hop_len=ap.hop_length, + hop_len=self.ap.hop_length, pad_short=config.pad_short, conv_pad=config.conv_pad, return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, @@ -360,3 +360,8 @@ class GAN(BaseVocoder): def get_criterion(self): """Return criterions for the optimizers""" return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)] + + @staticmethod + def init_from_config(config: Coqpit, verbose=True) -> "GAN": + ap = AudioProcessor.init_from_config(config, verbose=verbose) + return GAN(config, ap=ap) diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index 4ce743b3..fc15f3af 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -207,7 +207,7 @@ class HifiganGenerator(torch.nn.Module): self.ups.append( weight_norm( ConvTranspose1d( - upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py index e60baa9d..80b47870 100644 --- a/TTS/vocoder/models/melgan_generator.py +++ b/TTS/vocoder/models/melgan_generator.py @@ -36,7 +36,7 @@ class MelganGenerator(nn.Module): # upsampling layers and residual stacks for idx, upsample_factor in enumerate(upsample_factors): - layer_in_channels = base_channels // (2 ** idx) + layer_in_channels = base_channels // (2**idx) layer_out_channels = base_channels // (2 ** (idx + 1)) layer_filter_size = upsample_factor * 2 layer_stride = upsample_factor diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py index 9cc1061c..adf1bdae 100644 --- a/TTS/vocoder/models/parallel_wavegan_discriminator.py +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -35,7 +35,7 @@ class ParallelWaveganDiscriminator(nn.Module): if i == 0: dilation = 1 else: - dilation = i if dilation_factor == 1 else dilation_factor ** i + dilation = i if dilation_factor == 1 else dilation_factor**i conv_in_channels = conv_channels padding = (kernel_size - 1) // 2 * dilation conv_layer = [ diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index b8e78d03..ee9d8ad5 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -142,7 +142,7 @@ class ParallelWaveganGenerator(torch.nn.Module): self.apply(_apply_weight_norm) @staticmethod - def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x): + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): assert layers % stacks == 0 layers_per_cycle = layers // stacks dilations = [dilation(i % layers_per_cycle) for i in range(layers)] diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py index 8a66c537..2ee28c7b 100644 --- a/TTS/vocoder/models/univnet_generator.py +++ b/TTS/vocoder/models/univnet_generator.py @@ -130,7 +130,7 @@ class UnivnetGenerator(torch.nn.Module): self.apply(_apply_weight_norm) @staticmethod - def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x): + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): assert layers % stacks == 0 layers_per_cycle 
= layers // stacks dilations = [dilation(i % layers_per_cycle) for i in range(layers)] diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index ed4f4b37..c4968f1f 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -8,9 +8,9 @@ from torch import nn from torch.nn.utils import weight_norm from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.utils.io import load_fsspec -from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock from TTS.vocoder.models.base_vocoder import BaseVocoder @@ -153,7 +153,7 @@ class Wavegrad(BaseVocoder): noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a) noise_scale = noise_scale.unsqueeze(1) noise = torch.randn_like(y_0) - noisy_audio = noise_scale * y_0 + (1.0 - noise_scale ** 2) ** 0.5 * noise + noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] def compute_noise_level(self, beta): @@ -161,8 +161,8 @@ class Wavegrad(BaseVocoder): self.num_steps = len(beta) alpha = 1 - beta alpha_hat = np.cumprod(alpha) - noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0) - noise_level = alpha_hat ** 0.5 + noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0) + noise_level = alpha_hat**0.5 # pylint: disable=not-callable self.beta = torch.tensor(beta.astype(np.float32)) @@ -170,7 +170,7 @@ class Wavegrad(BaseVocoder): self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32)) self.noise_level = torch.tensor(noise_level.astype(np.float32)) - self.c1 = 1 / self.alpha ** 0.5 + self.c1 = 1 / self.alpha**0.5 self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5 self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5 @@ -270,12 +270,13 @@ class Wavegrad(BaseVocoder): ) -> None: pass - def test_run(self, assets: Dict, samples: List[Dict], outputs: Dict): # pylint: disable=unused-argument + def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument # setup noise schedule and inference ap = assets["audio_processor"] noise_schedule = self.config["test_noise_schedule"] betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) self.compute_noise_level(betas) + samples = test_loader.dataset.load_test_samples(1) for sample in samples: x = sample[0] x = x[None, :, :].to(next(self.parameters()).device) @@ -306,13 +307,11 @@ class Wavegrad(BaseVocoder): y = y.unsqueeze(1) return {"input": m, "waveform": y} - def get_data_loader( - self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int - ): + def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): ap = assets["audio_processor"] dataset = WaveGradDataset( ap=ap, - items=data_items, + items=samples, seq_len=self.config.seq_len, hop_len=ap.hop_length, pad_short=self.config.pad_short, @@ -339,3 +338,7 @@ class Wavegrad(BaseVocoder): noise_schedule = self.config["train_noise_schedule"] betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) self.compute_noise_level(betas) + + @staticmethod + def init_from_config(config: "WavegradConfig"): + return 
Wavegrad(config) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 1977efb6..6686db45 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -225,7 +225,7 @@ class Wavernn(BaseVocoder): super().__init__(config) if isinstance(self.args.mode, int): - self.n_classes = 2 ** self.args.mode + self.n_classes = 2**self.args.mode elif self.args.mode == "mold": self.n_classes = 3 * 10 elif self.args.mode == "gauss": @@ -568,12 +568,13 @@ class Wavernn(BaseVocoder): return self.train_step(batch, criterion) @torch.no_grad() - def test_run( - self, assets: Dict, samples: List[Dict], output: Dict # pylint: disable=unused-argument + def test( + self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument ) -> Tuple[Dict, Dict]: ap = assets["audio_processor"] figures = {} audios = {} + samples = test_loader.dataset.load_test_samples(1) for idx, sample in enumerate(samples): x = torch.FloatTensor(sample[0]) x = x.to(next(self.parameters()).device) @@ -600,14 +601,14 @@ class Wavernn(BaseVocoder): config: Coqpit, assets: Dict, is_eval: True, - data_items: List, + samples: List, verbose: bool, num_gpus: int, ): ap = assets["audio_processor"] dataset = WaveRNNDataset( ap=ap, - items=data_items, + items=samples, seq_len=config.seq_len, hop_len=ap.hop_length, pad=config.model_args.pad, @@ -631,3 +632,7 @@ class Wavernn(BaseVocoder): def get_criterion(self): # define train functions return WaveRNNLoss(self.args.mode) + + @staticmethod + def init_from_config(config: "WavernnConfig"): + return Wavernn(config) diff --git a/TTS/vocoder/tf/layers/melgan.py b/TTS/vocoder/tf/layers/melgan.py deleted file mode 100644 index 90bce6f1..00000000 --- a/TTS/vocoder/tf/layers/melgan.py +++ /dev/null @@ -1,54 +0,0 @@ -import tensorflow as tf - - -class ReflectionPad1d(tf.keras.layers.Layer): - def __init__(self, padding): - super().__init__() - self.padding = padding - - def call(self, x): - return tf.pad(x, [[0, 0], [self.padding, self.padding], [0, 0], [0, 0]], "REFLECT") - - -class ResidualStack(tf.keras.layers.Layer): - def __init__(self, channels, num_res_blocks, kernel_size, name): - super().__init__(name=name) - - assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd." 
- base_padding = (kernel_size - 1) // 2 - - self.blocks = [] - num_layers = 2 - for idx in range(num_res_blocks): - layer_kernel_size = kernel_size - layer_dilation = layer_kernel_size ** idx - layer_padding = base_padding * layer_dilation - block = [ - tf.keras.layers.LeakyReLU(0.2), - ReflectionPad1d(layer_padding), - tf.keras.layers.Conv2D( - filters=channels, - kernel_size=(kernel_size, 1), - dilation_rate=(layer_dilation, 1), - use_bias=True, - padding="valid", - name=f"blocks.{idx}.{num_layers}", - ), - tf.keras.layers.LeakyReLU(0.2), - tf.keras.layers.Conv2D( - filters=channels, kernel_size=(1, 1), use_bias=True, name=f"blocks.{idx}.{num_layers + 2}" - ), - ] - self.blocks.append(block) - self.shortcuts = [ - tf.keras.layers.Conv2D(channels, kernel_size=1, use_bias=True, name=f"shortcuts.{i}") - for i in range(num_res_blocks) - ] - - def call(self, x): - for block, shortcut in zip(self.blocks, self.shortcuts): - res = shortcut(x) - for layer in block: - x = layer(x) - x += res - return x diff --git a/TTS/vocoder/tf/layers/pqmf.py b/TTS/vocoder/tf/layers/pqmf.py deleted file mode 100644 index 042f2f08..00000000 --- a/TTS/vocoder/tf/layers/pqmf.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np -import tensorflow as tf -from scipy import signal as sig - - -class PQMF(tf.keras.layers.Layer): - def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): - super().__init__() - # define filter coefficient - self.N = N - self.taps = taps - self.cutoff = cutoff - self.beta = beta - - QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) - H = np.zeros((N, len(QMF))) - G = np.zeros((N, len(QMF))) - for k in range(N): - constant_factor = (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) - phase = (-1) ** k * np.pi / 4 - H[k] = 2 * QMF * np.cos(constant_factor + phase) - - G[k] = 2 * QMF * np.cos(constant_factor - phase) - - # [N, 1, taps + 1] == [filter_width, in_channels, out_channels] - self.H = np.transpose(H[:, None, :], (2, 1, 0)).astype("float32") - self.G = np.transpose(G[None, :, :], (2, 1, 0)).astype("float32") - - # filter for downsampling & upsampling - updown_filter = np.zeros((N, N, N), dtype=np.float32) - for k in range(N): - updown_filter[0, k, k] = 1.0 - self.updown_filter = updown_filter.astype(np.float32) - - def analysis(self, x): - """ - x : :math:`[B, 1, T]` - """ - x = tf.transpose(x, perm=[0, 2, 1]) - x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) - x = tf.nn.conv1d(x, self.H, stride=1, padding="VALID") - x = tf.nn.conv1d(x, self.updown_filter, stride=self.N, padding="VALID") - x = tf.transpose(x, perm=[0, 2, 1]) - return x - - def synthesis(self, x): - """ - x : B x D x T - """ - x = tf.transpose(x, perm=[0, 2, 1]) - x = tf.nn.conv1d_transpose( - x, - self.updown_filter * self.N, - strides=self.N, - output_shape=(tf.shape(x)[0], tf.shape(x)[1] * self.N, self.N), - ) - x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) - x = tf.nn.conv1d(x, self.G, stride=1, padding="VALID") - x = tf.transpose(x, perm=[0, 2, 1]) - return x diff --git a/TTS/vocoder/tf/models/melgan_generator.py b/TTS/vocoder/tf/models/melgan_generator.py deleted file mode 100644 index 09ee9530..00000000 --- a/TTS/vocoder/tf/models/melgan_generator.py +++ /dev/null @@ -1,133 +0,0 @@ -import logging -import os - -import tensorflow as tf - -from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # FATAL 
-logging.getLogger("tensorflow").setLevel(logging.FATAL) - -from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack - - -# pylint: disable=too-many-ancestors -# pylint: disable=abstract-method -class MelganGenerator(tf.keras.models.Model): - """Melgan Generator TF implementation dedicated for inference with no - weight norm""" - - def __init__( - self, - in_channels=80, - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=(8, 8, 2, 2), - res_kernel=3, - num_res_blocks=3, - ): - super().__init__() - - self.in_channels = in_channels - - # assert model parameters - assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." - - # setup additional model parameters - base_padding = (proj_kernel - 1) // 2 - act_slope = 0.2 - self.inference_padding = 2 - - # initial layer - self.initial_layer = [ - ReflectionPad1d(base_padding), - tf.keras.layers.Conv2D( - filters=base_channels, kernel_size=(proj_kernel, 1), strides=1, padding="valid", use_bias=True, name="1" - ), - ] - num_layers = 3 # count number of layers for layer naming - - # upsampling layers and residual stacks - self.upsample_layers = [] - for idx, upsample_factor in enumerate(upsample_factors): - layer_out_channels = base_channels // (2 ** (idx + 1)) - layer_filter_size = upsample_factor * 2 - layer_stride = upsample_factor - # layer_output_padding = upsample_factor % 2 - self.upsample_layers += [ - tf.keras.layers.LeakyReLU(act_slope), - tf.keras.layers.Conv2DTranspose( - filters=layer_out_channels, - kernel_size=(layer_filter_size, 1), - strides=(layer_stride, 1), - padding="same", - # output_padding=layer_output_padding, - use_bias=True, - name=f"{num_layers}", - ), - ResidualStack( - channels=layer_out_channels, - num_res_blocks=num_res_blocks, - kernel_size=res_kernel, - name=f"layers.{num_layers + 1}", - ), - ] - num_layers += num_res_blocks - 1 - - self.upsample_layers += [tf.keras.layers.LeakyReLU(act_slope)] - - # final layer - self.final_layers = [ - ReflectionPad1d(base_padding), - tf.keras.layers.Conv2D( - filters=out_channels, kernel_size=(proj_kernel, 1), use_bias=True, name=f"layers.{num_layers + 1}" - ), - tf.keras.layers.Activation("tanh"), - ] - - # self.model_layers = tf.keras.models.Sequential(self.initial_layer + self.upsample_layers + self.final_layers, name="layers") - self.model_layers = self.initial_layer + self.upsample_layers + self.final_layers - - @tf.function(experimental_relax_shapes=True) - def call(self, c, training=False): - """ - c : :math:`[B, C, T]` - """ - if training: - raise NotImplementedError() - return self.inference(c) - - def inference(self, c): - c = tf.transpose(c, perm=[0, 2, 1]) - c = tf.expand_dims(c, 2) - # FIXME: TF had no replicate padding as in Torch - # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT") - o = c - for layer in self.model_layers: - o = layer(o) - # o = self.model_layers(c) - o = tf.transpose(o, perm=[0, 3, 2, 1]) - return o[:, :, 0, :] - - def build_inference(self): - x = tf.random.uniform((1, self.in_channels, 4), dtype=tf.float32) - self(x, training=False) - - @tf.function( - experimental_relax_shapes=True, - input_signature=[ - tf.TensorSpec([1, None, None], dtype=tf.float32), - ], - ) - def inference_tflite(self, c): - c = tf.transpose(c, perm=[0, 2, 1]) - c = tf.expand_dims(c, 2) - # FIXME: TF had no replicate padding as in Torch - # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT") - o = c - for layer in 
self.model_layers: - o = layer(o) - # o = self.model_layers(c) - o = tf.transpose(o, perm=[0, 3, 2, 1]) - return o[:, :, 0, :] diff --git a/TTS/vocoder/tf/models/multiband_melgan_generator.py b/TTS/vocoder/tf/models/multiband_melgan_generator.py deleted file mode 100644 index 24d899b2..00000000 --- a/TTS/vocoder/tf/models/multiband_melgan_generator.py +++ /dev/null @@ -1,65 +0,0 @@ -import tensorflow as tf - -from TTS.vocoder.tf.layers.pqmf import PQMF -from TTS.vocoder.tf.models.melgan_generator import MelganGenerator - - -# pylint: disable=too-many-ancestors -# pylint: disable=abstract-method -class MultibandMelganGenerator(MelganGenerator): - def __init__( - self, - in_channels=80, - out_channels=4, - proj_kernel=7, - base_channels=384, - upsample_factors=(2, 8, 2, 2), - res_kernel=3, - num_res_blocks=3, - ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks, - ) - self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) - - def pqmf_analysis(self, x): - return self.pqmf_layer.analysis(x) - - def pqmf_synthesis(self, x): - return self.pqmf_layer.synthesis(x) - - def inference(self, c): - c = tf.transpose(c, perm=[0, 2, 1]) - c = tf.expand_dims(c, 2) - # FIXME: TF had no replicate padding as in Torch - # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT") - o = c - for layer in self.model_layers: - o = layer(o) - o = tf.transpose(o, perm=[0, 3, 2, 1]) - o = self.pqmf_layer.synthesis(o[:, :, 0, :]) - return o - - @tf.function( - experimental_relax_shapes=True, - input_signature=[ - tf.TensorSpec([1, 80, None], dtype=tf.float32), - ], - ) - def inference_tflite(self, c): - c = tf.transpose(c, perm=[0, 2, 1]) - c = tf.expand_dims(c, 2) - # FIXME: TF had no replicate padding as in Torch - # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT") - o = c - for layer in self.model_layers: - o = layer(o) - o = tf.transpose(o, perm=[0, 3, 2, 1]) - o = self.pqmf_layer.synthesis(o[:, :, 0, :]) - return o diff --git a/TTS/vocoder/tf/utils/__init__.py b/TTS/vocoder/tf/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py b/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py deleted file mode 100644 index 453d8b78..00000000 --- a/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np -import tensorflow as tf - - -def compare_torch_tf(torch_tensor, tf_tensor): - """Compute the average absolute difference b/w torch and tf tensors""" - return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean() - - -def convert_tf_name(tf_name): - """Convert certain patterns in TF layer names to Torch patterns""" - tf_name_tmp = tf_name - tf_name_tmp = tf_name_tmp.replace(":0", "") - tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0") - tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1") - tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh") - tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight") - tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight") - tf_name_tmp = tf_name_tmp.replace("/beta", "/bias") - tf_name_tmp = tf_name_tmp.replace("/", ".") - return tf_name_tmp - - -def transfer_weights_torch_to_tf(tf_vars, 
var_map_dict, state_dict): - """Transfer weigths from torch state_dict to TF variables""" - print(" > Passing weights from Torch to TF ...") - for tf_var in tf_vars: - torch_var_name = var_map_dict[tf_var.name] - print(f" | > {tf_var.name} <-- {torch_var_name}") - # if tuple, it is a bias variable - if "kernel" in tf_var.name: - torch_weight = state_dict[torch_var_name] - numpy_weight = torch_weight.permute([2, 1, 0]).numpy()[:, None, :, :] - if "bias" in tf_var.name: - torch_weight = state_dict[torch_var_name] - numpy_weight = torch_weight - assert np.all( - tf_var.shape == numpy_weight.shape - ), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" - tf.keras.backend.set_value(tf_var, numpy_weight) - return tf_vars - - -def load_tf_vars(model_tf, tf_vars): - for tf_var in tf_vars: - model_tf.get_layer(tf_var.name).set_weights(tf_var) - return model_tf diff --git a/TTS/vocoder/tf/utils/generic_utils.py b/TTS/vocoder/tf/utils/generic_utils.py deleted file mode 100644 index 94364ab4..00000000 --- a/TTS/vocoder/tf/utils/generic_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -import importlib -import re - - -def to_camel(text): - text = text.capitalize() - return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) - - -def setup_generator(c): - print(" > Generator Model: {}".format(c.generator_model)) - MyModel = importlib.import_module("TTS.vocoder.tf.models." + c.generator_model.lower()) - MyModel = getattr(MyModel, to_camel(c.generator_model)) - if c.generator_model in "melgan_generator": - model = MyModel( - in_channels=c.audio["num_mels"], - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=c.generator_model_params["upsample_factors"], - res_kernel=3, - num_res_blocks=c.generator_model_params["num_res_blocks"], - ) - if c.generator_model in "melgan_fb_generator": - pass - if c.generator_model in "multiband_melgan_generator": - model = MyModel( - in_channels=c.audio["num_mels"], - out_channels=4, - proj_kernel=7, - base_channels=384, - upsample_factors=c.generator_model_params["upsample_factors"], - res_kernel=3, - num_res_blocks=c.generator_model_params["num_res_blocks"], - ) - return model diff --git a/TTS/vocoder/tf/utils/io.py b/TTS/vocoder/tf/utils/io.py deleted file mode 100644 index 3de8adab..00000000 --- a/TTS/vocoder/tf/utils/io.py +++ /dev/null @@ -1,31 +0,0 @@ -import datetime -import pickle - -import fsspec -import tensorflow as tf - - -def save_checkpoint(model, current_step, epoch, output_path, **kwargs): - """Save TF Vocoder model""" - state = { - "model": model.weights, - "step": current_step, - "epoch": epoch, - "date": datetime.date.today().strftime("%B %d, %Y"), - } - state.update(kwargs) - with fsspec.open(output_path, "wb") as f: - pickle.dump(state, f) - - -def load_checkpoint(model, checkpoint_path): - """Load TF Vocoder model""" - with fsspec.open(checkpoint_path, "rb") as f: - checkpoint = pickle.load(f) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} - tf_vars = model.weights - for tf_var in tf_vars: - layer_name = tf_var.name - chkp_var_value = chkp_var_dict[layer_name] - tf.keras.backend.set_value(tf_var, chkp_var_value) - return model diff --git a/TTS/vocoder/tf/utils/tflite.py b/TTS/vocoder/tf/utils/tflite.py deleted file mode 100644 index 876739fd..00000000 --- a/TTS/vocoder/tf/utils/tflite.py +++ /dev/null @@ -1,27 +0,0 @@ -import fsspec -import tensorflow as tf - - -def convert_melgan_to_tflite(model, output_path=None, 
experimental_converter=True): - """Convert Tensorflow MelGAN model to TFLite. Save a binary file if output_path is - provided, else return TFLite model.""" - - concrete_function = model.inference_tflite.get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function]) - converter.experimental_new_converter = experimental_converter - converter.optimizations = [] - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - tflite_model = converter.convert() - print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") - if output_path is not None: - # same model binary if outputpath is provided - with fsspec.open(output_path, "wb") as f: - f.write(tflite_model) - return None - return tflite_model - - -def load_tflite_model(tflite_path): - tflite_model = tf.lite.Interpreter(model_path=tflite_path) - tflite_model.allocate_tensors() - return tflite_model diff --git a/docs/source/_templates/page.html b/docs/source/_templates/page.html new file mode 100644 index 00000000..633b7661 --- /dev/null +++ b/docs/source/_templates/page.html @@ -0,0 +1,23 @@ + +{% extends "!page.html" %} +{% block scripts %} + {{ super() }} + + + + + + + +{% endblock %} diff --git a/docs/source/converting_torch_to_tf.md b/docs/source/converting_torch_to_tf.md deleted file mode 100644 index 20a0be6b..00000000 --- a/docs/source/converting_torch_to_tf.md +++ /dev/null @@ -1,21 +0,0 @@ -# Converting Torch to TF 2 - -Currently, 🐸TTS supports the vanilla Tacotron2 and MelGAN models in TF 2.It does not support advanced attention methods and other small tricks used by the Torch models. You can convert any Torch model trained after v0.0.2. - -You can also export TF 2 models to TFLite for even faster inference. - -## How to convert from Torch to TF 2.0 -Make sure you installed Tensorflow v2.2. It is not installed by default by :frog: TTS. - -All the TF related code stays under ```tf``` folder. - -To convert a **compatible** Torch model, run the following command with the right arguments: - -```bash -python TTS/bin/convert_tacotron2_torch_to_tf.py\ - --torch_model_path /path/to/torch/model.pth.tar \ - --config_path /path/to/model/config.json\ - --output_path /path/to/output/tf/model -``` - -This will create a TF model file. Notice that our model format is not compatible with the official TF checkpoints. We created our custom format to match Torch checkpoints we use. Therefore, use the ```load_checkpoint``` and ```save_checkpoint``` functions provided under ```TTS.tf.generic_utils```. diff --git a/docs/source/finetuning.md b/docs/source/finetuning.md index 42b9e518..7d7ef1cb 100644 --- a/docs/source/finetuning.md +++ b/docs/source/finetuning.md @@ -9,33 +9,33 @@ them and fine-tune it for your own dataset. This will help you in two main ways: 1. Faster learning Since a pre-trained model has already learned features that are relevant for the task, it will converge faster on - a new dataset. This will reduce the cost of training and let you experient faster. + a new dataset. This will reduce the cost of training and let you experiment faster. 2. Better resutls with small datasets Deep learning models are data hungry and they give better performance with more data. However, it is not always - possible to have this abondance, especially in domain. For instance, LJSpeech dataset, that we released most of - our English models with, is almost 24 hours long. 
And it requires for someone to collect thid amount of data with - a help of a voice talent takes weeks. + possible to have this abundance, especially in specific domains. For instance, the LJSpeech dataset, that we released most of + our English models with, is almost 24 hours long. It takes weeks to record this amount of data with + the help of a voice actor. - Fine-tuning cames to rescue in this case. You can take one of our pre-trained models and fine-tune it for your own - speech dataset and achive reasonable results with only a couple of hours in the worse case. + Fine-tuning comes to the rescue in this case. You can take one of our pre-trained models and fine-tune it on your own + speech dataset and achieve reasonable results with only a couple of hours of data. - However, note that, fine-tuning does not promise great results. The model performance is still depends on the + However, note that fine-tuning does not ensure great results. The model performance still depends on the {ref}`dataset quality ` and the hyper-parameters you choose for fine-tuning. Therefore, - it still demands a bit of tinkering. + it still takes a bit of tinkering. ## Steps to fine-tune a 🐸 TTS model 1. Setup your dataset. - You need to format your target dataset in a certain way so that 🐸TTS data loader would be able to load it for the + You need to format your target dataset in a certain way so that the 🐸TTS data loader will be able to load it for the training. Please see {ref}`this page ` for more information about formatting. 2. Choose the model you want to fine-tune. - You can list the availabe models on terminal as + You can list the available models in the command line with ```bash tts --list_models @@ -43,15 +43,15 @@ them and fine-tune it for your own dataset. This will help you in two main ways: The command above lists the models in a naming format as ```<model_type>/<language>/<dataset>/<model_name>```. - Or you can manually check `.model.json` file in the project directory. + Or you can manually check the `.model.json` file in the project directory. You should choose the model based on your requirements. Some models are fast and some are better in speech quality. - One lazy way to check a model is running the model on the hardware you want to use and see how it works. For + One lazy way to test a model is to run it on the hardware you want to use and see how it works. For simple testing, you can use the `tts` command on the terminal. For more info see {ref}`here `. 3. Download the model. - You can download the model by `tts` command. If you run `tts` with a particular model, it will download automatically + You can download the model by using the `tts` command. If you run `tts` with a particular model, it will download it automatically and the model path will be printed on the terminal. ```bash @@ -78,12 +78,12 @@ them and fine-tune it for your own dataset. This will help you in two main ways: - `run_name` field: This is the name of the run. This is used to name the output directory and the entry in the logging dashboard. - `output_path` field: This is the path where the fine-tuned model is saved. - - `lr` field: You may need to use a smaller learning rate for fine-tuning to not impair the features learned by the + - `lr` field: You may need to use a smaller learning rate for fine-tuning to not lose the features learned by the pre-trained model with big update steps. - `audio` fields: Different datasets have different audio characteristics. You must check the current audio parameters and make sure that the values reflect your dataset. For instance, your dataset might have a different audio sampling rate. - Apart from these above, you should check the whole configuration file and make sure that the values are correct for + Apart from the parameters above, you should check the whole configuration file and make sure that the values are correct for your dataset and training.
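To make the config adjustments above concrete, a minimal sketch (not part of the diff) of changing those fields in code, assuming the stock `load_config` helper; the `config.json` name and all values are placeholders, and the same overrides can instead be passed to the training script as `--coqpit.<field>` flags (see the fine-tuning step below):

```python
from TTS.config import load_config

# "config.json" stands in for the config file shipped with the downloaded model.
config = load_config("config.json")

config.run_name = "ljspeech-finetune"  # names the output folder and the logging dashboard entry
config.output_path = "finetune-output/"
config.lr = 1e-5                       # smaller LR so big updates do not wash out pre-trained features
config.audio.sample_rate = 22050       # assuming the standard `audio` sub-config; must match your dataset
```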
5. Start fine-tuning. @@ -112,4 +112,3 @@ them and fine-tune it for your own dataset. This will help you in two main ways: --coqpit.lr 0.00001 ``` - diff --git a/docs/source/formatting_your_dataset.md b/docs/source/formatting_your_dataset.md index 3db38af0..294d2b29 100644 --- a/docs/source/formatting_your_dataset.md +++ b/docs/source/formatting_your_dataset.md @@ -19,15 +19,15 @@ Let's assume you created the audio clips and their transcription. You can collec You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each line must be delimited by a special character separating the audio file name from the transcription. And make sure that the delimiter is not used in the transcription text. -We recommend the following format delimited by `||`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc. +We recommend the following format delimited by `|`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc. ``` # metadata.txt -audio1||This is my sentence. -audio2||This is maybe my sentence. -audio3||This is certainly my sentence. -audio4||Let this be your sentence. +audio1|This is my sentence. +audio2|This is maybe my sentence. +audio3|This is certainly my sentence. +audio4|Let this be your sentence. ... ``` @@ -58,23 +58,68 @@ If you use a different dataset format than the LJSpeech or the other public data If your dataset is in a new language or it needs special normalization steps, then you need a new `text_cleaner`. -What you get out of a `formatter` is a `List[List[]]` in the following format. +What you get out of a `formatter` is a `List[Dict]` in the following format. ``` >>> formatter(metafile_path) -[["audio1.wav", "This is my sentence.", "MyDataset"], -["audio1.wav", "This is maybe a sentence.", "MyDataset"], -... +[ + {"audio_file":"audio1.wav", "text":"This is my sentence.", "speaker_name":"MyDataset", "language": "lang_code"}, + {"audio_file":"audio1.wav", "text":"This is maybe a sentence.", "speaker_name":"MyDataset", "language": "lang_code"}, + ... ] ``` -Each sub-list is parsed as ```["<audio_file>", "<text>", "<speaker_name>"]```. ```<speaker_name>``` is the dataset name for single speaker datasets and it is mainly used in the multi-speaker models to map the speaker of each sample. But for now, we only focus on single speaker datasets. -The purpose of a `formatter` is to parse your metafile and load the audio file paths and transcriptions. Then, its output passes to a `Dataset` object. It computes features from the audio signals, calls text normalization routines, and converts raw text to +The purpose of a `formatter` is to parse your manifest file and load the audio file paths and transcriptions. +Then, the output is passed to the `Dataset`. It computes features from the audio signals, calls text normalization routines, and converts raw text to phonemes if needed.
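To make the `List[Dict]` formatter contract above concrete, a small self-contained sketch (not part of the diff) that runs a minimal formatter over a toy manifest; the folder layout, file names, and the `my_speaker` label are made up:

```python
import os
import tempfile

def toy_formatter(root_path, manifest_file, **kwargs):
    """Parse `<filename>|<text>` lines into the sample dicts described above."""
    items = []
    with open(os.path.join(root_path, manifest_file), "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.strip().split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            items.append({"text": cols[1], "audio_file": wav_file, "speaker_name": "my_speaker"})
    return items

# Toy manifest mirroring the metadata.txt example above.
root = tempfile.mkdtemp()
with open(os.path.join(root, "metadata.txt"), "w", encoding="utf-8") as f:
    f.write("audio1|This is my sentence.\naudio2|This is maybe my sentence.\n")

for item in toy_formatter(root, "metadata.txt"):
    print(item)  # {"text": ..., "audio_file": ".../wavs/audio1.wav", "speaker_name": "my_speaker"}
```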
+
+```python
+from TTS.tts.configs.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+
+
+# dataset config for one of the pre-defined datasets
+dataset_config = BaseDatasetConfig(
+    name="vctk", meta_file_train="", language="en-us", path="dataset-path"
+)
+
+# load training samples
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+```
+
+Load a custom dataset with a custom formatter.
+
+```python
+import os
+
+from TTS.tts.datasets import load_tts_samples
+
+
+# custom formatter implementation
+def formatter(root_path, manifest_file, **kwargs):  # pylint: disable=unused-argument
+    """Assumes each line as ```<filename>|<transcription>```
+    """
+    txt_file = os.path.join(root_path, manifest_file)
+    items = []
+    speaker_name = "my_speaker"
+    with open(txt_file, "r", encoding="utf-8") as ttf:
+        for line in ttf:
+            cols = line.split("|")
+            wav_file = os.path.join(root_path, "wavs", cols[0])
+            text = cols[1]
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+    return items
+
+# load training samples
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True, formatter=formatter)
+```
+
 See `TTS.tts.datasets.TTSDataset`, a generic `Dataset` implementation for the `tts` models.

 See `TTS.vocoder.datasets.*` for different `Dataset` implementations for the `vocoder` models.
diff --git a/docs/source/index.md b/docs/source/index.md
index 756cea8e..9dc5bfce 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -27,7 +27,6 @@
     formatting_your_dataset
     what_makes_a_good_dataset
     tts_datasets
-    converting_torch_to_tf

 .. toctree::
     :maxdepth: 2
diff --git a/docs/source/installation.md b/docs/source/installation.md
index 6532ee8e..0122271d 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -12,12 +12,6 @@ You can install from PyPI as follows:
 pip install TTS  # from PyPI
 ```

-By default, this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra.
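Tying the dataset-formatting additions above together, here is a minimal, illustrative sketch that loads a `|`-delimited `metadata.txt` (the layout described in `formatting_your_dataset.md`) through a custom formatter. The `/data/my_dataset/` path, the `my_dataset` name, and the `my_speaker` label are placeholders chosen for this example, not values defined anywhere in the codebase.

```python
import os

from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples


# Placeholder layout, mirroring the metadata.txt example above:
#   /data/my_dataset/metadata.txt        (lines like "audio1|This is my sentence.")
#   /data/my_dataset/wavs/audio1.wav
dataset_config = BaseDatasetConfig(
    name="my_dataset", meta_file_train="metadata.txt", path="/data/my_dataset/"
)


def my_formatter(root_path, manifest_file, **kwargs):  # pylint: disable=unused-argument
    """Parse `<filename>|<transcription>` lines into the dict format shown above."""
    items = []
    with open(os.path.join(root_path, manifest_file), "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.strip().split("|")
            items.append(
                {
                    "text": cols[1],
                    # metadata.txt stores file names without the extension, so append ".wav"
                    "audio_file": os.path.join(root_path, "wavs", cols[0] + ".wav"),
                    "speaker_name": "my_speaker",
                }
            )
    return items


# eval_split=True returns a (train, eval) split, as in the snippets above
train_samples, eval_samples = load_tts_samples(
    dataset_config, eval_split=True, formatter=my_formatter
)
print(len(train_samples), len(eval_samples))
print(train_samples[0])
```

Since a `formatter` callable is passed explicitly, the `name` field of `BaseDatasetConfig` should only act as a label here rather than selecting one of the pre-defined formatters.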
- -```bash -pip install TTS[tf] -``` - Or install from Github: ```bash diff --git a/notebooks/Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb b/notebooks/Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb deleted file mode 100644 index 8a25132c..00000000 --- a/notebooks/Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb +++ /dev/null @@ -1,425 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "6LWsNd3_M3MP" - }, - "source": [ - "# Converting Pytorch models to Tensorflow and TFLite by CoquiTTS" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FAqrSIWgLyP0" - }, - "source": [ - "This is a tutorial demonstrating Coqui TTS capabilities to convert \n", - "trained PyTorch models to Tensorflow and Tflite.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MBJjGYnoEo4v" - }, - "source": [ - "# Installation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Ku-dA4DKoeXk" - }, - "source": [ - "### Download TF Models and configs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 162 - }, - "colab_type": "code", - "id": "jGIgnWhGsxU1", - "outputId": "b461952f-8507-4dd2-af06-4e6b8692765d", - "tags": [] - }, - "outputs": [], - "source": [ - "!gdown --id 1dntzjWFg7ufWaTaFy80nRz-Tu02xWZos -O data/tts_model.pth.tar\n", - "!gdown --id 18CQ6G6tBEOfvCHlPqP8EBI4xWbrr9dBc -O data/config.json" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 235 - }, - "colab_type": "code", - "id": "4dnpE0-kvTsu", - "outputId": "f67c3138-bda0-4b3e-ffcc-647f9feec23e", - "tags": [] - }, - "outputs": [], - "source": [ - "!gdown --id 1Ty5DZdOc0F7OTGj9oJThYbL5iVu_2G0K -O data/vocoder_model.pth.tar\n", - "!gdown --id 1Rd0R_nRCrbjEdpOwq6XwZAktvugiBvmu -O data/config_vocoder.json\n", - "!gdown --id 11oY3Tv0kQtxK_JPgxrfesa99maVXHNxU -O data/scale_stats.npy" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3IGvvCRMEwqn" - }, - "source": [ - "# Model Conversion PyTorch -> TF -> TFLite" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "tLhz8SAf8Pgp" - }, - "source": [ - "## Converting PyTorch to Tensorflow\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "colab_type": "code", - "id": "Xsrvr_WQ8Ib5", - "outputId": "dae96616-e5f7-41b6-cdb9-5026cfcd3214", - "tags": [] - }, - "outputs": [], - "source": [ - "# convert TTS model to Tensorflow\n", - "!python ../TTS/bin/convert_tacotron2_torch_to_tf.py --config_path data/config.json --torch_model_path data/tts_model.pth.tar --output_path data/tts_model_tf.pkl" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "colab_type": "code", - "id": "VJ4NA5If9ljv", - "outputId": "1520dca8-1db8-4e07-bc0c-b1d5941c775e", - "tags": [] - }, - "outputs": [], - "source": [ - "# convert Vocoder model to Tensorflow\n", - "!python ../TTS/bin/convert_melgan_torch_to_tf.py --config_path data/config_vocoder.json --torch_model_path data/vocoder_model.pth.tar --output_path data/vocoder_model_tf.pkl" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": 
"text", - "id": "7d5vTkBZ-BYQ" - }, - "source": [ - "## Converting Tensorflow to TFLite" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 927 - }, - "colab_type": "code", - "id": "33hTfpuU99cg", - "outputId": "8a0e5be1-23a2-4128-ee37-8232adcb8ff0", - "tags": [] - }, - "outputs": [], - "source": [ - "# convert TTS model to TFLite\n", - "!python ../TTS/bin/convert_tacotron2_tflite.py --config_path data/config.json --tf_model data/tts_model_tf.pkl --output_path data/tts_model.tflite" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 364 - }, - "colab_type": "code", - "id": "e00Hm75Y-wZ2", - "outputId": "42381b05-3c9d-44f0-dac7-d81efd95eadf", - "tags": [] - }, - "outputs": [], - "source": [ - "# convert Vocoder model to TFLite\n", - "!python ../TTS/bin/convert_melgan_tflite.py --config_path data/config_vocoder.json --tf_model data/vocoder_model_tf.pkl --output_path data/vocoder_model.tflite" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Zlgi8fPdpRF0" - }, - "source": [ - "# Run Inference with TFLite " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "f-Yc42nQZG5A" - }, - "outputs": [], - "source": [ - "def run_vocoder(mel_spec):\n", - " vocoder_inputs = mel_spec[None, :, :]\n", - " # get input and output details\n", - " input_details = vocoder_model.get_input_details()\n", - " # reshape input tensor for the new input shape\n", - " vocoder_model.resize_tensor_input(input_details[0]['index'], vocoder_inputs.shape)\n", - " vocoder_model.allocate_tensors()\n", - " detail = input_details[0]\n", - " vocoder_model.set_tensor(detail['index'], vocoder_inputs)\n", - " # run the model\n", - " vocoder_model.invoke()\n", - " # collect outputs\n", - " output_details = vocoder_model.get_output_details()\n", - " waveform = vocoder_model.get_tensor(output_details[0]['index'])\n", - " return waveform \n", - "\n", - "\n", - "def tts(model, text, CONFIG, p):\n", - " t_1 = time.time()\n", - " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens, inputs = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, style_wav=None,\n", - " truncated=False, enable_eos_bos_chars=CONFIG.enable_eos_bos_chars,\n", - " backend='tflite')\n", - " waveform = run_vocoder(mel_postnet_spec.T)\n", - " waveform = waveform[0, 0]\n", - " rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n", - " tps = (time.time() - t_1) / len(waveform)\n", - " print(waveform.shape)\n", - " print(\" > Run-time: {}\".format(time.time() - t_1))\n", - " print(\" > Real-time factor: {}\".format(rtf))\n", - " print(\" > Time per step: {}\".format(tps))\n", - " IPython.display.display(IPython.display.Audio(waveform, rate=CONFIG.audio['sample_rate'])) \n", - " return alignment, mel_postnet_spec, stop_tokens, waveform" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "ZksegYQepkFg" - }, - "source": [ - "### Load TF Models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "oVa0kOamprgj" - }, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "import time\n", - "import IPython\n", - "\n", - "from TTS.tts.tf.utils.tflite import load_tflite_model\n", - "from TTS.tts.tf.utils.io import load_checkpoint\n", - 
"from TTS.utils.io import load_config\n", - "from TTS.tts.utils.text.symbols import symbols, phonemes\n", - "from TTS.utils.audio import AudioProcessor\n", - "from TTS.tts.utils.synthesis import synthesis" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "EY-sHVO8IFSH" - }, - "outputs": [], - "source": [ - "# runtime settings\n", - "use_cuda = False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "_1aIUp2FpxOQ" - }, - "outputs": [], - "source": [ - "# model paths\n", - "TTS_MODEL = \"data/tts_model.tflite\"\n", - "TTS_CONFIG = \"data/config.json\"\n", - "VOCODER_MODEL = \"data/vocoder_model.tflite\"\n", - "VOCODER_CONFIG = \"data/config_vocoder.json\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "CpgmdBVQplbv" - }, - "outputs": [], - "source": [ - "# load configs\n", - "TTS_CONFIG = load_config(TTS_CONFIG)\n", - "VOCODER_CONFIG = load_config(VOCODER_CONFIG)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 471 - }, - "colab_type": "code", - "id": "zmrQxiozIUVE", - "outputId": "21cda136-de87-4d55-fd46-7d5306103d90", - "tags": [] - }, - "outputs": [], - "source": [ - "# load the audio processor\n", - "TTS_CONFIG.audio['stats_path'] = 'data/scale_stats.npy'\n", - "ap = AudioProcessor(**TTS_CONFIG.audio) " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "8fLoI4ipqMeS" - }, - "outputs": [], - "source": [ - "# LOAD TTS MODEL\n", - "# multi speaker \n", - "speaker_id = None\n", - "speakers = []\n", - "\n", - "# load the models\n", - "model = load_tflite_model(TTS_MODEL)\n", - "vocoder_model = load_tflite_model(VOCODER_MODEL)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Ws_YkPKsLgo-" - }, - "source": [ - "## Run Sample Sentence" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 134 - }, - "colab_type": "code", - "id": "FuWxZ9Ey5Puj", - "outputId": "535c2df1-c27c-458b-e14b-41a977635aa1", - "tags": [] - }, - "outputs": [], - "source": [ - "sentence = \"Bill got in the habit of asking himself “Is that thought true?” and if he wasn’t absolutely certain it was, he just let it go.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, TTS_CONFIG, ap)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "Tutorial_Converting_PyTorch_to_TF_to_TFlite.ipynb", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/dataset_analysis/AnalyzeDataset.ipynb b/notebooks/dataset_analysis/AnalyzeDataset.ipynb index c2aabbf9..e08f3ab3 100644 --- a/notebooks/dataset_analysis/AnalyzeDataset.ipynb +++ b/notebooks/dataset_analysis/AnalyzeDataset.ipynb @@ 
-8,7 +8,7 @@ }, "outputs": [], "source": [ - "TTS_PATH = \"/home/erogol/projects/\"" + "# TTS_PATH = \"/home/erogol/projects/\"" ] }, { @@ -21,7 +21,6 @@ "source": [ "import os\n", "import sys\n", - "sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n", "import librosa\n", "import numpy as np\n", "import pandas as pd\n", @@ -30,6 +29,8 @@ "from multiprocessing import Pool\n", "from matplotlib import pylab as plt\n", "from collections import Counter\n", + "from TTS.config.shared_configs import BaseDatasetConfig\n", + "from TTS.tts.datasets import load_tts_samples\n", "from TTS.tts.datasets.formatters import *\n", "%matplotlib inline" ] @@ -42,22 +43,29 @@ }, "outputs": [], "source": [ - "DATA_PATH = \"/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/\"\n", - "META_DATA = [\"kleinzaches/metadata.csv\",\n", - " \"spiegel_kaetzchen/metadata.csv\",\n", - " \"herrnarnesschatz/metadata.csv\",\n", - " \"maedchen_von_moorhof/metadata.csv\",\n", - " \"koenigsgaukler/metadata.csv\",\n", - " \"altehous/metadata.csv\",\n", - " \"odysseus/metadata.csv\",\n", - " \"undine/metadata.csv\",\n", - " \"reise_tilsit/metadata.csv\",\n", - " \"schmied_seines_glueckes/metadata.csv\",\n", - " \"kammmacher/metadata.csv\",\n", - " \"unterm_birnbaum/metadata.csv\",\n", - " \"liebesbriefe/metadata.csv\",\n", - " \"sandmann/metadata.csv\"]\n", - "NUM_PROC = 8" + "NUM_PROC = 8\n", + "DATASET_CONFIG = BaseDatasetConfig(\n", + " name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=\"/home/ubuntu/TTS/depot/data/male_dataset1_44k/\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def formatter(root_path, meta_file, **kwargs): # pylint: disable=unused-argument\n", + " txt_file = os.path.join(root_path, meta_file)\n", + " items = []\n", + " speaker_name = \"maledataset1\"\n", + " with open(txt_file, \"r\", encoding=\"utf-8\") as ttf:\n", + " for line in ttf:\n", + " cols = line.split(\"|\")\n", + " wav_file = os.path.join(root_path, \"wavs\", cols[0])\n", + " text = cols[1]\n", + " items.append([text, wav_file, speaker_name])\n", + " return items" ] }, { @@ -69,8 +77,10 @@ "outputs": [], "source": [ "# use your own preprocessor at this stage - TTS/datasets/proprocess.py\n", - "items = mailabs(DATA_PATH, META_DATA)\n", - "print(\" > Number of audio files: {}\".format(len(items)))" + "train_samples, eval_samples = load_tts_samples(DATASET_CONFIG, eval_split=True, formatter=formatter)\n", + "items = train_samples + eval_samples\n", + "print(\" > Number of audio files: {}\".format(len(items)))\n", + "print(items[1])" ] }, { @@ -103,6 +113,15 @@ "print([item for item, count in c.items() if count > 1])" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "item" + ] + }, { "cell_type": "code", "execution_count": null, @@ -112,11 +131,9 @@ "outputs": [], "source": [ "def load_item(item):\n", - " file_name = item[1].strip()\n", " text = item[0].strip()\n", - " audio = librosa.load(file_name, sr=None)\n", - " sr = audio[1]\n", - " audio = audio[0]\n", + " file_name = item[1].strip()\n", + " audio, sr = librosa.load(file_name, sr=None)\n", " audio_len = len(audio) / sr\n", " text_len = len(text)\n", " return file_name, text, text_len, audio, audio_len\n", @@ -374,11 +391,18 @@ "# fequency bar plot - it takes time!!\n", "w_count_df.plot.bar()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { 
"kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -392,7 +416,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/recipes/README.md b/recipes/README.md index cf3f3de9..21a6727d 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -11,6 +11,12 @@ $ sh ./recipes//download_.sh $ python recipes///train.py ``` +For some datasets you might need to resample the audio files. For example, VCTK dataset can be resampled to 22050Hz as follows. + +```console +python TTS/bin/resample.py --input_dir recipes/vctk/VCTK/wav48_silence_trimmed --output_sr 22050 --output_dir recipes/vctk/VCTK/wav48_silence_trimmed --n_jobs 8 --file_ext flac +``` + If you train a new model using TTS, feel free to share your training to expand the list of recipes. You can also open a new discussion and share your progress with the 🐸 community. \ No newline at end of file diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py index 68b67d66..f1b29025 100644 --- a/recipes/ljspeech/align_tts/train_aligntts.py +++ b/recipes/ljspeech/align_tts/train_aligntts.py @@ -1,9 +1,12 @@ import os -from TTS.trainer import Trainer, TrainingArgs -from TTS.tts.configs.align_tts_config import AlignTTSConfig, BaseDatasetConfig +from trainer import Trainer, TrainerArgs + +from TTS.tts.configs.align_tts_config import AlignTTSConfig +from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.align_tts import AlignTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -31,23 +34,32 @@ config = AlignTTSConfig( datasets=[dataset_config], ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init model -model = AlignTTS(config) +model = AlignTTS(config, ap, tokenizer) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/ljspeech/download_ljspeech.sh b/recipes/ljspeech/download_ljspeech.sh index 14ef058d..9468988a 100644 --- a/recipes/ljspeech/download_ljspeech.sh +++ b/recipes/ljspeech/download_ljspeech.sh @@ -10,5 +10,5 @@ tar -xjf LJSpeech-1.1.tar.bz2 shuf LJSpeech-1.1/metadata.csv > LJSpeech-1.1/metadata_shuf.csv head -n 12000 LJSpeech-1.1/metadata_shuf.csv > LJSpeech-1.1/metadata_train.csv tail -n 1100 LJSpeech-1.1/metadata_shuf.csv > LJSpeech-1.1/metadata_val.csv -mv LJSpeech-1.1 $RUN_DIR/ +mv LJSpeech-1.1 $RUN_DIR/recipes/ljspeech/ rm LJSpeech-1.1.tar.bz2 \ No newline at end of file diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 0a4a965b..a3fc35c9 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -1,10 +1,12 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.manage import ModelManager @@ -46,9 +48,9 @@ config = FastPitchConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=4, print_step=50, print_eval=False, mixed_precision=False, @@ -67,23 +69,28 @@ if not config.model_args.use_aligner: f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" ) -# init audio processor -ap = AudioProcessor(**config.audio) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. 
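+# For example, a custom formatter (such as the one shown in
+# docs/source/formatting_your_dataset.md) could be passed here as:
+#     train_samples, eval_samples = load_tts_samples(
+#         dataset_config, eval_split=True, formatter=my_formatter)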
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init the model -model = ForwardTTS(config) +model = ForwardTTS(config, ap, tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) trainer.fit() diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py index a71da94b..560d3de2 100644 --- a/recipes/ljspeech/fast_speech/train_fast_speech.py +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -1,10 +1,12 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.manage import ModelManager @@ -45,9 +47,9 @@ config = FastSpeechConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=8, print_step=50, print_eval=False, mixed_precision=False, @@ -66,23 +68,28 @@ if not config.model_args.use_aligner: f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" ) -# init audio processor -ap = AudioProcessor(**config.audio) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init the model -model = ForwardTTS(config) +model = ForwardTTS(config, ap, tokenizer) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) trainer.fit() diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index 7bd9ea19..c47cd00a 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -2,7 +2,7 @@ import os # Trainer: Where the ✨️ happens. # TrainingArgs: Defines the set of arguments of the Trainer. 
-from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs # GlowTTSConfig: all model related values for training, validating and testing. from TTS.tts.configs.glow_tts_config import GlowTTSConfig @@ -11,6 +11,7 @@ from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.glow_tts import GlowTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # we use the same path as this script as our training folder. @@ -47,7 +48,12 @@ config = GlowTTSConfig( # INITIALIZE THE AUDIO PROCESSOR # Audio processor is used for feature extraction and audio I/O. # It mainly serves to the dataloader and the training loggers. -ap = AudioProcessor(**config.audio.to_dict()) +ap = AudioProcessor.init_from_config(config) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) # LOAD DATA SAMPLES # Each sample is a list of ```[text, audio_file_path, speaker_name]``` @@ -60,19 +66,13 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # Models take a config object and a speaker manager as input # Config defines the details of the model like the number of layers, the size of the embedding, etc. # Speaker manager is used by multi-speaker models. -model = GlowTTS(config, speaker_manager=None) +model = GlowTTS(config, ap, tokenizer, speaker_manager=None) # INITIALIZE THE TRAINER # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, # assets are objetcs used by the models but not class members. + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 
🚀 diff --git a/recipes/ljspeech/hifigan/train_hifigan.py b/recipes/ljspeech/hifigan/train_hifigan.py index 8d1c272a..1e5bbf30 100644 --- a/recipes/ljspeech/hifigan/train_hifigan.py +++ b/recipes/ljspeech/hifigan/train_hifigan.py @@ -1,6 +1,7 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import HifiganConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -40,7 +41,7 @@ model = GAN(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py index 90c52997..40ff5a00 100644 --- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py +++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py @@ -1,6 +1,7 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import MultibandMelganConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -40,7 +41,7 @@ model = GAN(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 6b9683af..7ad132b2 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -1,10 +1,12 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -38,9 +40,9 @@ config = SpeedySpeechConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=4, print_step=50, print_eval=False, mixed_precision=False, @@ -50,32 +52,32 @@ config = SpeedySpeechConfig( datasets=[dataset_config], ) -# # compute alignments -# if not config.model_args.use_aligner: -# manager = ModelManager() -# model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") -# # TODO: make compute_attention python callable -# os.system( -# f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" -# ) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. 
+# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) -# load training samples +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init model -model = ForwardTTS(config) +model = ForwardTTS(config, ap, tokenizer) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 🚀 trainer.fit() diff --git a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py index cf00ccc2..ea1b0874 100644 --- a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py +++ b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py @@ -1,11 +1,13 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # from TTS.tts.datasets.tokenizer import Tokenizer @@ -38,10 +40,16 @@ config = Tacotron2Config( # This is the config that is saved for the future use num_eval_loader_workers=4, run_eval=True, test_delay_epochs=-1, - ga_alpha=5.0, + ga_alpha=0.0, + decoder_loss_alpha=0.25, + postnet_loss_alpha=0.25, + postnet_diff_spec_alpha=0, + decoder_diff_spec_alpha=0, + decoder_ssim_alpha=0, + postnet_ssim_alpha=0, r=2, attention_type="dynamic_convolution", - double_decoder_consistency=True, + double_decoder_consistency=False, epochs=1000, text_cleaner="phoneme_cleaners", use_phonemes=True, @@ -54,23 +62,35 @@ config = Tacotron2Config( # This is the config that is saved for the future use datasets=[dataset_config], ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. 
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) -# init model -model = Tacotron2(config) +# INITIALIZE THE MODEL +# Models take a config object and a speaker manager as input +# Config defines the details of the model like the number of layers, the size of the embedding, etc. +# Speaker manager is used by multi-speaker models. +model = Tacotron2(config, ap, tokenizer) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 🚀 trainer.fit() diff --git a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py index b452094a..d00f8ed7 100644 --- a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py +++ b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py @@ -1,11 +1,13 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # from TTS.tts.datasets.tokenizer import Tokenizer @@ -46,6 +48,7 @@ config = Tacotron2Config( # This is the config that is saved for the future use use_phonemes=True, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=8, print_step=25, print_eval=True, mixed_precision=False, @@ -56,15 +59,32 @@ config = Tacotron2Config( # This is the config that is saved for the future use # init audio processor ap = AudioProcessor(**config.audio.to_dict()) -# load training samples +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) -# init model -model = Tacotron2(config) +# INITIALIZE THE MODEL +# Models take a config object and a speaker manager as input +# Config defines the details of the model like the number of layers, the size of the embedding, etc. +# Speaker manager is used by multi-speaker models. 
+model = Tacotron2(config, ap, tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py index 589fd027..19c91925 100644 --- a/recipes/ljspeech/univnet/train.py +++ b/recipes/ljspeech/univnet/train.py @@ -1,6 +1,7 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import UnivnetConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -39,7 +40,7 @@ model = GAN(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index e86cc861..cfb3351d 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -1,11 +1,13 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.vits import Vits +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -32,10 +34,10 @@ audio_config = BaseAudioConfig( config = VitsConfig( audio=audio_config, run_name="vits_ljspeech", - batch_size=48, + batch_size=32, eval_batch_size=16, batch_group_size=5, - num_loader_workers=4, + num_loader_workers=0, num_eval_loader_workers=4, run_eval=True, test_delay_epochs=-1, @@ -48,28 +50,37 @@ config = VitsConfig( print_step=25, print_eval=True, mixed_precision=True, - max_seq_len=500000, output_path=output_path, datasets=[dataset_config], ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# config is updated with the default characters if not defined in the config. +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. 
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init model -model = Vits(config) +model = Vits(config, ap, tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples, - training_assets={"audio_processor": ap}, ) trainer.fit() diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py index 6786c052..1abdf45d 100644 --- a/recipes/ljspeech/wavegrad/train_wavegrad.py +++ b/recipes/ljspeech/wavegrad/train_wavegrad.py @@ -1,6 +1,7 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import WavegradConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -37,7 +38,7 @@ model = Wavegrad(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py index f64f5752..640f5092 100644 --- a/recipes/ljspeech/wavernn/train_wavernn.py +++ b/recipes/ljspeech/wavernn/train_wavernn.py @@ -1,6 +1,7 @@ import os -from TTS.trainer import Trainer, TrainingArgs +from trainer import Trainer, TrainerArgs + from TTS.utils.audio import AudioProcessor from TTS.vocoder.configs import WavernnConfig from TTS.vocoder.datasets.preprocess import load_wav_data @@ -39,7 +40,7 @@ model = Wavernn(config) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py index be4747df..ac2c21a2 100644 --- a/recipes/multilingual/vits_tts/train_vits_tts.py +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -1,8 +1,9 @@ import os from glob import glob +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples @@ -68,8 +69,8 @@ config = VitsConfig( print_eval=False, mixed_precision=False, sort_by_audio_len=True, - min_seq_len=32 * 256 * 4, - max_seq_len=160000, + min_audio_len=32 * 256 * 4, + max_audio_len=160000, output_path=output_path, datasets=dataset_config, characters={ @@ -119,7 +120,7 @@ model = Vits(config, speaker_manager, language_manager) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py index f40587e0..986202c5 100644 --- a/recipes/vctk/fast_pitch/train_fast_pitch.py +++ b/recipes/vctk/fast_pitch/train_fast_pitch.py @@ -1,11 +1,13 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.fast_pitch_config import FastPitchConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -32,6 
+34,7 @@ config = FastPitchConfig( num_loader_workers=8, num_eval_loader_workers=4, compute_input_seq_cache=True, + precompute_num_workers=4, compute_f0=True, f0_cache_path=os.path.join(output_path, "f0_cache"), run_eval=True, @@ -39,23 +42,35 @@ config = FastPitchConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, - max_seq_len=500000, + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=500000, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, ) -# init audio processor -ap = AudioProcessor(**config.audio) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -65,16 +80,14 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.model_args.num_speakers = speaker_manager.num_speakers # init model -model = ForwardTTS(config, speaker_manager) +model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py index b2988809..fe785a41 100644 --- a/recipes/vctk/fast_speech/train_fast_speech.py +++ b/recipes/vctk/fast_speech/train_fast_speech.py @@ -1,11 +1,13 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.fast_speech_config import FastSpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -25,37 +27,48 @@ audio_config = BaseAudioConfig( ) config = FastSpeechConfig( - run_name="fast_pitch_ljspeech", + run_name="fast_speech_vctk", audio=audio_config, batch_size=32, eval_batch_size=16, num_loader_workers=8, num_eval_loader_workers=4, compute_input_seq_cache=True, - compute_f0=True, - f0_cache_path=os.path.join(output_path, "f0_cache"), + precompute_num_workers=4, run_eval=True, test_delay_epochs=-1, epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, - max_seq_len=500000, + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=500000, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, ) -# init audio processor -ap = AudioProcessor(**config.audio) +## INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -65,16 +78,14 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.model_args.num_speakers = speaker_manager.num_speakers # init model -model = ForwardTTS(config, speaker_manager) +model = ForwardTTS(config, ap, tokenizer, speaker_manager=speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py index 8c9f5388..ebdbfb37 100644 --- a/recipes/vctk/glow_tts/train_glow_tts.py +++ b/recipes/vctk/glow_tts/train_glow_tts.py @@ -1,12 +1,14 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.glow_tts import GlowTTS from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # set experiment paths @@ -32,6 +34,7 @@ config = GlowTTSConfig( eval_batch_size=16, num_loader_workers=4, num_eval_loader_workers=4, + precompute_num_workers=4, run_eval=True, test_delay_epochs=-1, epochs=1000, @@ -45,12 +48,27 @@ config = GlowTTSConfig( output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=500000, ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -60,16 +78,14 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.num_speakers = speaker_manager.num_speakers # init model -model = GlowTTS(config, speaker_manager) +model = GlowTTS(config, ap, tokenizer, speaker_manager=speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/speedy_speech/train_speedy_speech.py b/recipes/vctk/speedy_speech/train_speedy_speech.py index 81f78d26..80d21ca2 100644 --- a/recipes/vctk/speedy_speech/train_speedy_speech.py +++ b/recipes/vctk/speedy_speech/train_speedy_speech.py @@ -1,11 +1,13 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -32,30 +34,41 @@ config = SpeedySpeechConfig( num_loader_workers=8, num_eval_loader_workers=4, compute_input_seq_cache=True, - compute_f0=True, - f0_cache_path=os.path.join(output_path, "f0_cache"), + precompute_num_workers=4, run_eval=True, test_delay_epochs=-1, epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, - max_seq_len=500000, + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=500000, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, ) -# init audio processor -ap = AudioProcessor(**config.audio) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -65,16 +78,14 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.model_args.num_speakers = speaker_manager.num_speakers # init model -model = ForwardTTS(config, speaker_manager) +model = ForwardTTS(config, ap, tokenizer, speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py index b0030f17..bed21ad9 100644 --- a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py +++ b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py @@ -1,12 +1,14 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron_config import TacotronConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron import Tacotron from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -32,6 +34,7 @@ config = TacotronConfig( # This is the config that is saved for the future use eval_batch_size=16, num_loader_workers=4, num_eval_loader_workers=4, + precompute_num_workers=4, run_eval=True, test_delay_epochs=-1, r=6, @@ -45,18 +48,30 @@ config = TacotronConfig( # This is the config that is saved for the future use print_step=25, print_eval=False, mixed_precision=True, - sort_by_audio_len=True, - min_seq_len=0, - max_seq_len=44000 * 10, # 44k is the original sampling rate before resampling, corresponds to 10 seconds of audio + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=44000 * 10, # 44k is the original sampling rate before resampling, corresponds to 10 seconds of audio output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, # set this to enable multi-sepeaker training ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +## INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -65,16 +80,14 @@ speaker_manager = SpeakerManager() speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) # init model -model = Tacotron(config, speaker_manager) +model = Tacotron(config, ap, tokenizer, speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
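These recipes also swap the old `sort_by_audio_len`/`min_seq_len`/`max_seq_len` knobs for separate text and audio length bounds. A minimal illustration of the new fields, reusing the values from the recipe above (the actual filtering is done later by the dataset):

```python
# Minimal sketch of the new length-filtering fields; the values mirror the recipe above.
from TTS.tts.configs.tacotron_config import TacotronConfig

config = TacotronConfig(
    min_text_len=0,            # samples with fewer characters are dropped
    max_text_len=500,          # samples with more characters are dropped
    min_audio_len=0,           # audio bounds are counted in samples
    max_audio_len=44000 * 10,  # roughly 10 seconds at the original 44 kHz rate
)
```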
🚀 trainer.fit() diff --git a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py index 63efb784..caa745b3 100644 --- a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py +++ b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py @@ -1,12 +1,14 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -44,9 +46,10 @@ config = Tacotron2Config( # This is the config that is saved for the future use print_step=150, print_eval=False, mixed_precision=True, - sort_by_audio_len=True, - min_seq_len=14800, - max_seq_len=22050 * 10, # 44k is the original sampling rate before resampling, corresponds to 10 seconds of audio + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=44000 * 10, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, # set this to enable multi-sepeaker training @@ -60,10 +63,21 @@ config = Tacotron2Config( # This is the config that is saved for the future use lr=3e-5, ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -72,16 +86,14 @@ speaker_manager = SpeakerManager() speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) # init model -model = Tacotron2(config, speaker_manager) +model = Tacotron2(config, ap, tokenizer, speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/tacotron2/train_tacotron2.py b/recipes/vctk/tacotron2/train_tacotron2.py index 346d650b..43f5d4e6 100644 --- a/recipes/vctk/tacotron2/train_tacotron2.py +++ b/recipes/vctk/tacotron2/train_tacotron2.py @@ -1,12 +1,14 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.tacotron2_config import Tacotron2Config from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -44,9 +46,10 @@ config = Tacotron2Config( # This is the config that is saved for the future use print_step=150, print_eval=False, mixed_precision=True, - sort_by_audio_len=True, - min_seq_len=14800, - max_seq_len=22050 * 10, # 44k is the original sampling rate before resampling, corresponds to 10 seconds of audio + min_text_len=0, + max_text_len=500, + min_audio_len=0, + max_audio_len=44000 * 10, output_path=output_path, datasets=[dataset_config], use_speaker_embedding=True, # set this to enable multi-sepeaker training @@ -60,10 +63,21 @@ config = Tacotron2Config( # This is the config that is saved for the future use lr=3e-5, ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +## INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -72,16 +86,14 @@ speaker_manager = SpeakerManager() speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) # init model -model = Tacotron2(config, speaker_manager) +model = Tacotron2(config, ap, tokenizer, speaker_manager) -# init the trainer and 🚀 +# INITIALIZE THE TRAINER +# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, +# distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples, - training_assets={"audio_processor": ap}, + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) + +# AND... 3,2,1... 
🚀 trainer.fit() diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py index 7eb741c4..dff4eefc 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -1,12 +1,14 @@ import os +from trainer import Trainer, TrainerArgs + from TTS.config.shared_configs import BaseAudioConfig -from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.models.vits import Vits, VitsArgs from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) @@ -56,17 +58,26 @@ config = VitsConfig( print_step=25, print_eval=False, mixed_precision=True, - sort_by_audio_len=True, - min_seq_len=32 * 256 * 4, - max_seq_len=1500000, + max_text_len=325, # change this if you have a larger VRAM than 16GB output_path=output_path, datasets=[dataset_config], ) -# init audio processor -ap = AudioProcessor(**config.audio.to_dict()) +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) -# load training samples +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# config is updated with the default characters if not defined in the config. +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. 
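Elsewhere in this patch the loaded samples are indexed as dicts rather than positional lists (see the formatter and sampler test updates below). A small sketch of calling `load_tts_samples` on its own, with the formatter name and path as placeholders:

```python
# Small sketch of loading samples on their own; formatter name and path are placeholders.
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples

dataset_config = BaseDatasetConfig(
    name="vctk", meta_file_train="", path="/path/to/VCTK/", language="en"
)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)

# The formatter and sampler tests in this patch read samples by key rather than by position.
print(train_samples[0]["text"], train_samples[0]["audio_file"])
```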
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) # init speaker manager for multi-speaker training @@ -76,16 +87,15 @@ speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples) config.model_args.num_speakers = speaker_manager.num_speakers # init model -model = Vits(config, speaker_manager) +model = Vits(config, ap, tokenizer, speaker_manager) # init the trainer and 🚀 trainer = Trainer( - TrainingArgs(), + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples, - training_assets={"audio_processor": ap}, ) trainer.fit() diff --git a/requirements.tf.txt b/requirements.tf.txt deleted file mode 100644 index 8e256a90..00000000 --- a/requirements.tf.txt +++ /dev/null @@ -1 +0,0 @@ -tensorflow==2.5.0 diff --git a/requirements.txt b/requirements.txt index ddb6def9..6e30c26e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,30 +1,38 @@ -cython -flask -gdown -inflect -jieba -librosa==0.8.0 -matplotlib +# core deps numpy==1.19.5 -pandas -pypinyin -pysbd -pyyaml -scipy>=0.19.0 -soundfile -tensorboardX +cython +scipy>=1.4.0 torch>=1.7 -tqdm +torchaudio +soundfile +librosa==0.8.0 numba==0.53 -umap-learn==0.5.1 +inflect +tqdm anyascii -coqpit +pyyaml +fsspec>=2021.04.0 +# deps for examples +flask +# deps for inference +pysbd +# deps for notebooks +umap-learn==0.5.1 +pandas +# deps for training +matplotlib +tensorboardX +pyworld +# coqui stack +trainer @ git+https://github.com/coqui-ai/trainer.git +coqpit # config managemenr +# chinese g2p deps +jieba +pypinyin # japanese g2p deps mecab-python3==1.0.3 unidic-lite==1.0.8 # gruut+supported langs gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0 -fsspec>=2021.04.0 -pyworld -webrtcvad -torchaudio +# others +webrtcvad # for VAD diff --git a/setup.py b/setup.py index 95f0841b..96173fec 100644 --- a/setup.py +++ b/setup.py @@ -9,8 +9,8 @@ # ,+++*. . .*++, ,++*. .*+++* # *+, .,*++**. .**++**. ,+* # .+* *+, -# *+. .+* -# *+* +++ +++ *+* +# *+. Coqui .+* +# *+* +++ TTS +++ *+* # .+++*. . . *+++. # ,+* *+++*... ...*+++* *+, # .++. .""""+++++++****+++++++"""". ++. 
@@ -35,8 +35,6 @@ if LooseVersion(sys.version) < LooseVersion("3.6") or LooseVersion(sys.version) raise RuntimeError("TTS requires python >= 3.6 and <=3.10 " "but your Python version is {}".format(sys.version)) -cwd = os.path.dirname(os.path.abspath(__file__)) - cwd = os.path.dirname(os.path.abspath(__file__)) with open(os.path.join(cwd, "TTS", "VERSION")) as fin: version = fin.read().strip() @@ -65,9 +63,7 @@ with open(os.path.join(cwd, "requirements.notebooks.txt"), "r") as f: requirements_notebooks = f.readlines() with open(os.path.join(cwd, "requirements.dev.txt"), "r") as f: requirements_dev = f.readlines() -with open(os.path.join(cwd, "requirements.tf.txt"), "r") as f: - requirements_tf = f.readlines() -requirements_all = requirements_dev + requirements_notebooks + requirements_tf +requirements_all = requirements_dev + requirements_notebooks with open("README.md", "r", encoding="utf-8") as readme_file: README = readme_file.read() @@ -116,7 +112,6 @@ setup( "all": requirements_all, "dev": requirements_dev, "notebooks": requirements_notebooks, - "tf": requirements_tf, }, python_requires=">=3.6.0, <3.10", entry_points={"console_scripts": ["tts=TTS.bin.synthesize:main", "tts-server = TTS.server.server:main"]}, diff --git a/tests/__init__.py b/tests/__init__.py index 0a0c3379..8906c8c7 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -26,6 +26,11 @@ def get_tests_input_path(): return os.path.join(get_tests_path(), "inputs") +def get_tests_data_path(): + """Returns the path to the test data directory.""" + return os.path.join(get_tests_path(), "data") + + def get_tests_output_path(): """Returns the path to the directory for test outputs.""" return os.path.join(get_tests_path(), "outputs") diff --git a/tests/aux_tests/test_find_unique_phonemes.py b/tests/aux_tests/test_find_unique_phonemes.py index fa0abe4b..fa740ba3 100644 --- a/tests/aux_tests/test_find_unique_phonemes.py +++ b/tests/aux_tests/test_find_unique_phonemes.py @@ -39,7 +39,6 @@ class TestFindUniquePhonemes(unittest.TestCase): num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, @@ -64,7 +63,6 @@ class TestFindUniquePhonemes(unittest.TestCase): num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, diff --git a/tests/aux_tests/test_text_processing.py b/tests/aux_tests/test_text_processing.py deleted file mode 100644 index 62d60a42..00000000 --- a/tests/aux_tests/test_text_processing.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Tests for text to phoneme converstion""" -import unittest - -from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme, text2phone - -# ----------------------------------------------------------------------------- - -LANG = "en-us" - -EXAMPLE_TEXT = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" 
- -EXPECTED_PHONEMES = "ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ| f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s|,| ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ| ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l| f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ|!" - -# ----------------------------------------------------------------------------- - - -class TextProcessingTestCase(unittest.TestCase): - """Tests for text to phoneme conversion""" - - def test_phoneme_to_sequence(self): - """Verify en-us sentence phonemes without blank token""" - self._test_phoneme_to_sequence(add_blank=False) - - def test_phoneme_to_sequence_with_blank_token(self): - """Verify en-us sentence phonemes with blank token""" - self._test_phoneme_to_sequence(add_blank=True) - - def _test_phoneme_to_sequence(self, add_blank): - """Verify en-us sentence phonemes""" - text_cleaner = ["phoneme_cleaners"] - sequence = phoneme_to_sequence(EXAMPLE_TEXT, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = EXPECTED_PHONEMES.replace("|", "") - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # multiple punctuations - text = "Be a voice, not an! echo?" - sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?" - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # not ending with punctuation - text = "Be a voice, not an! echo" - sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # original - text = "Be a voice, not an echo!" - sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!" - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # extra space after the sentence - text = "Be a voice, not an! echo. " - sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ." - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - # extra space after the sentence - text = "Be a voice, not an! echo. " - sequence = phoneme_to_sequence( - text, text_cleaner, LANG, enable_eos_bos=True, add_blank=add_blank, use_espeak_phonemes=True - ) - text_hat = sequence_to_phoneme(sequence) - text_hat_with_params = sequence_to_phoneme(sequence) - gt = "^biː ɐ vɔɪs, nɑːt ɐn! 
ɛkoʊ.~" - print(text_hat) - print(len(sequence)) - self.assertEqual(text_hat, text_hat_with_params) - self.assertEqual(text_hat, gt) - - def test_text2phone(self): - """Verify phones directly (with |)""" - ph = text2phone(EXAMPLE_TEXT, LANG, use_espeak_phonemes=True) - self.assertEqual(ph, EXPECTED_PHONEMES) - - -# ----------------------------------------------------------------------------- - -if __name__ == "__main__": - unittest.main() diff --git a/tests/data/ljspeech/f0_cache/pitch_stats.npy b/tests/data/ljspeech/f0_cache/pitch_stats.npy new file mode 100644 index 00000000..aaa385c3 Binary files /dev/null and b/tests/data/ljspeech/f0_cache/pitch_stats.npy differ diff --git a/tests/data_tests/test_dataset_formatters.py b/tests/data_tests/test_dataset_formatters.py index bd83002c..30fb79a8 100644 --- a/tests/data_tests/test_dataset_formatters.py +++ b/tests/data_tests/test_dataset_formatters.py @@ -5,13 +5,13 @@ from tests import get_tests_input_path from TTS.tts.datasets.formatters import common_voice -class TestPreprocessors(unittest.TestCase): +class TestTTSFormatters(unittest.TestCase): def test_common_voice_preprocessor(self): # pylint: disable=no-self-use root_path = get_tests_input_path() meta_file = "common_voice.tsv" items = common_voice(root_path, meta_file) - assert items[0][0] == "The applicants are invited for coffee and visa is given immediately." - assert items[0][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav") + assert items[0]["text"] == "The applicants are invited for coffee and visa is given immediately." + assert items[0]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav") - assert items[-1][0] == "Competition for limited resources has also resulted in some local conflicts." - assert items[-1][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav") + assert items[-1]["text"] == "Competition for limited resources has also resulted in some local conflicts." + assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav") diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 19c2e8f7..0562fbf7 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -7,9 +7,9 @@ import torch from torch.utils.data import DataLoader from tests import get_tests_output_path -from TTS.tts.configs.shared_configs import BaseTTSConfig -from TTS.tts.datasets import TTSDataset -from TTS.tts.datasets.formatters import ljspeech +from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig +from TTS.tts.datasets import TTSDataset, load_tts_samples +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # pylint: disable=unused-variable @@ -18,11 +18,19 @@ OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/") os.makedirs(OUTPATH, exist_ok=True) # create a dummy config for testing data loaders. 
-c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2) +c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2, use_noise_augment=False) c.r = 5 c.data_path = "tests/data/ljspeech/" ok_ljspeech = os.path.exists(c.data_path) +dataset_config = BaseDatasetConfig( + name="ljspeech_test", # ljspeech_test to multi-speaker + meta_file_train="metadata.csv", + meta_file_val=None, + path=c.data_path, + language="en", +) + DATA_EXIST = True if not os.path.exists(c.data_path): DATA_EXIST = False @@ -36,25 +44,26 @@ class TestTTSDataset(unittest.TestCase): self.max_loader_iter = 4 self.ap = AudioProcessor(**c.audio) - def _create_dataloader(self, batch_size, r, bgs): - items = ljspeech(c.data_path, "metadata.csv") + def _create_dataloader(self, batch_size, r, bgs, start_by_longest=False): - # add a default language because now the TTSDataset expect a language - language = "" - items = [[*item, language] for item in items] + # load dataset + meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2) + items = meta_data_train + meta_data_eval + tokenizer, _ = TTSTokenizer.init_from_config(c) dataset = TTSDataset( - r, - c.text_cleaner, + outputs_per_step=r, compute_linear_spec=True, return_wav=True, + tokenizer=tokenizer, ap=self.ap, - meta_data=items, - characters=c.characters, + samples=items, batch_group_size=bgs, - min_seq_len=c.min_seq_len, - max_seq_len=float("inf"), - use_phonemes=False, + min_text_len=c.min_text_len, + max_text_len=c.max_text_len, + min_audio_len=c.min_audio_len, + max_audio_len=c.max_audio_len, + start_by_longest=start_by_longest, ) dataloader = DataLoader( dataset, @@ -67,89 +76,112 @@ class TestTTSDataset(unittest.TestCase): return dataloader, dataset def test_loader(self): - if ok_ljspeech: - dataloader, dataset = self._create_dataloader(2, c.r, 0) - - for i, data in enumerate(dataloader): - if i == self.max_loader_iter: - break - text_input = data["text"] - text_lengths = data["text_lengths"] - speaker_name = data["speaker_names"] - linear_input = data["linear"] - mel_input = data["mel"] - mel_lengths = data["mel_lengths"] - stop_target = data["stop_targets"] - item_idx = data["item_idxs"] - wavs = data["waveform"] - - neg_values = text_input[text_input < 0] - check_count = len(neg_values) - assert check_count == 0, " !! 
Negative values in text_input: {}".format(check_count) - assert isinstance(speaker_name[0], str) - assert linear_input.shape[0] == c.batch_size - assert linear_input.shape[2] == self.ap.fft_size // 2 + 1 - assert mel_input.shape[0] == c.batch_size - assert mel_input.shape[2] == c.audio["num_mels"] - assert ( - wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length - ), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}" - - # make sure that the computed mels and the waveform match and correctly computed - mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy()) - ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length) - mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg] - assert abs(mel_diff.sum()) < 1e-5 - - # check normalization ranges - if self.ap.symmetric_norm: - assert mel_input.max() <= self.ap.max_norm - assert mel_input.min() >= -self.ap.max_norm # pylint: disable=invalid-unary-operand-type - assert mel_input.min() < 0 - else: - assert mel_input.max() <= self.ap.max_norm - assert mel_input.min() >= 0 - - def test_batch_group_shuffle(self): - if ok_ljspeech: - dataloader, dataset = self._create_dataloader(2, c.r, 16) - last_length = 0 - frames = dataset.items - for i, data in enumerate(dataloader): - if i == self.max_loader_iter: - break - text_input = data["text"] - text_lengths = data["text_lengths"] - speaker_name = data["speaker_names"] - linear_input = data["linear"] - mel_input = data["mel"] - mel_lengths = data["mel_lengths"] - stop_target = data["stop_targets"] - item_idx = data["item_idxs"] - - avg_length = mel_lengths.numpy().mean() - assert avg_length >= last_length - dataloader.dataset.sort_and_filter_items() - is_items_reordered = False - for idx, item in enumerate(dataloader.dataset.items): - if item != frames[idx]: - is_items_reordered = True - break - assert is_items_reordered - - def test_padding_and_spec(self): if ok_ljspeech: dataloader, dataset = self._create_dataloader(1, 1, 0) for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data["text"] - text_lengths = data["text_lengths"] + text_input = data["token_id"] + _ = data["token_id_lengths"] speaker_name = data["speaker_names"] linear_input = data["linear"] mel_input = data["mel"] mel_lengths = data["mel_lengths"] + _ = data["stop_targets"] + _ = data["item_idxs"] + wavs = data["waveform"] + + neg_values = text_input[text_input < 0] + check_count = len(neg_values) + + # check basic conditions + self.assertEqual(check_count, 0) + self.assertEqual(linear_input.shape[0], mel_input.shape[0], c.batch_size) + self.assertEqual(linear_input.shape[2], self.ap.fft_size // 2 + 1) + self.assertEqual(mel_input.shape[2], c.audio["num_mels"]) + self.assertEqual(wavs.shape[1], mel_input.shape[1] * c.audio.hop_length) + self.assertIsInstance(speaker_name[0], str) + + # make sure that the computed mels and the waveform match and correctly computed + mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy()) + # remove padding in mel-spectrogram + mel_dataloader = mel_input[0].T.numpy()[:, : mel_lengths[0]] + # guarantee that both mel-spectrograms have the same size and that we will remove waveform padding + mel_new = mel_new[:, : mel_lengths[0]] + ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length) + mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg] + self.assertLess(abs(mel_diff.sum()), 1e-5) + + # check normalization ranges + if self.ap.symmetric_norm: + 
self.assertLessEqual(mel_input.max(), self.ap.max_norm) + self.assertGreaterEqual( + mel_input.min(), -self.ap.max_norm # pylint: disable=invalid-unary-operand-type + ) + self.assertLess(mel_input.min(), 0) + else: + self.assertLessEqual(mel_input.max(), self.ap.max_norm) + self.assertGreaterEqual(mel_input.min(), 0) + + def test_batch_group_shuffle(self): + if ok_ljspeech: + dataloader, dataset = self._create_dataloader(2, c.r, 16) + last_length = 0 + frames = dataset.samples + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + mel_lengths = data["mel_lengths"] + avg_length = mel_lengths.numpy().mean() + dataloader.dataset.preprocess_samples() + is_items_reordered = False + for idx, item in enumerate(dataloader.dataset.samples): + if item != frames[idx]: + is_items_reordered = True + break + self.assertGreaterEqual(avg_length, last_length) + self.assertTrue(is_items_reordered) + + def test_start_by_longest(self): + """Test start_by_longest option. + + The first item of the first batch must be longer than all the other items. + """ + if ok_ljspeech: + dataloader, _ = self._create_dataloader(2, c.r, 0, True) + dataloader.dataset.preprocess_samples() + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + mel_lengths = data["mel_lengths"] + if i == 0: + max_len = mel_lengths[0] + print(mel_lengths) + self.assertTrue(all(max_len >= mel_lengths)) + + def test_padding_and_spectrograms(self): + def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths): + self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding + self.assertNotEqual(linear_input[idx, -2].sum(), 0) + self.assertNotEqual(mel_input[idx, -1].sum(), 0) + self.assertNotEqual(mel_input[idx, -2].sum(), 0) + self.assertEqual(stop_target[idx, -1], 1) + self.assertEqual(stop_target[idx, -2], 0) + self.assertEqual(stop_target[idx].sum(), 1) + self.assertEqual(len(mel_lengths.shape), 1) + self.assertEqual(mel_lengths[idx], linear_input[idx].shape[0]) + self.assertEqual(mel_lengths[idx], mel_input[idx].shape[0]) + + if ok_ljspeech: + dataloader, _ = self._create_dataloader(1, 1, 0) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + linear_input = data["linear"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] stop_target = data["stop_targets"] item_idx = data["item_idxs"] @@ -161,7 +193,7 @@ class TestTTSDataset(unittest.TestCase): # NOTE: Below needs to check == 0 but due to an unknown reason # there is a slight difference between two matrices. # TODO: Check this assert cond more in detail.
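Taken together, the rewritten assertions pin down a few invariants of every batch the loader returns. Continuing the loader sketch from earlier (its `loader`, `config`, and `ap` are reused), they read roughly as:

```python
# Continues the loader sketch above: `loader`, `config`, and `ap` are reused from that snippet.
batch = next(iter(loader))

# Invariants the rewritten assertions check on every batch (with outputs_per_step=1):
assert batch["linear"].shape[2] == ap.fft_size // 2 + 1                 # linear spectrogram bins
assert batch["mel"].shape[2] == config.audio["num_mels"]                # mel bins
assert batch["waveform"].shape[1] == batch["mel"].shape[1] * config.audio.hop_length
assert batch["stop_targets"][0, int(batch["mel_lengths"][0]) - 1] == 1  # stop flag on the last real frame
```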
- assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T - mel_dl).max() + self.assertLess(abs(mel.T - mel_dl).max(), 1e-5) # check mel-spec correctness mel_spec = mel_input[0].cpu().numpy() @@ -175,56 +207,36 @@ class TestTTSDataset(unittest.TestCase): self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav") shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav") - # check the last time step to be zero padded - assert linear_input[0, -1].sum() != 0 - assert linear_input[0, -2].sum() != 0 - assert mel_input[0, -1].sum() != 0 - assert mel_input[0, -2].sum() != 0 - assert stop_target[0, -1] == 1 - assert stop_target[0, -2] == 0 - assert stop_target.sum() == 1 - assert len(mel_lengths.shape) == 1 - assert mel_lengths[0] == linear_input[0].shape[0] - assert mel_lengths[0] == mel_input[0].shape[0] + # check the outputs + check_conditions(0, linear_input, mel_input, stop_target, mel_lengths) # Test for batch size 2 - dataloader, dataset = self._create_dataloader(2, 1, 0) + dataloader, _ = self._create_dataloader(2, 1, 0) for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data["text"] - text_lengths = data["text_lengths"] - speaker_name = data["speaker_names"] linear_input = data["linear"] mel_input = data["mel"] mel_lengths = data["mel_lengths"] stop_target = data["stop_targets"] item_idx = data["item_idxs"] + # set id to the longest sequence in the batch if mel_lengths[0] > mel_lengths[1]: idx = 0 else: idx = 1 - # check the first item in the batch - assert linear_input[idx, -1].sum() != 0 - assert linear_input[idx, -2].sum() != 0, linear_input - assert mel_input[idx, -1].sum() != 0 - assert mel_input[idx, -2].sum() != 0, mel_input - assert stop_target[idx, -1] == 1 - assert stop_target[idx, -2] == 0 - assert stop_target[idx].sum() == 1 - assert len(mel_lengths.shape) == 1 - assert mel_lengths[idx] == mel_input[idx].shape[0] - assert mel_lengths[idx] == linear_input[idx].shape[0] + # check the longer item in the batch + check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths) - # check the second itme in the batch - assert linear_input[1 - idx, -1].sum() == 0 - assert mel_input[1 - idx, -1].sum() == 0 - assert stop_target[1, mel_lengths[1] - 1] == 1 - assert stop_target[1, mel_lengths[1] :].sum() == stop_target.shape[1] - mel_lengths[1] - assert len(mel_lengths.shape) == 1 + # check the other item in the batch + self.assertEqual(linear_input[1 - idx, -1].sum(), 0) + self.assertEqual(mel_input[1 - idx, -1].sum(), 0) + self.assertEqual(stop_target[1, mel_lengths[1] - 1], 1) + self.assertEqual(stop_target[1, mel_lengths[1] :].sum(), stop_target.shape[1] - mel_lengths[1]) + self.assertEqual(len(mel_lengths.shape), 1) # check batch zero-frame conditions (zero-frame disabled) # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 3d8d6c75..497a3fb5 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -39,7 +39,7 @@ random_sampler = torch.utils.data.RandomSampler(train_samples) ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)]) en, pt = 0, 0 for index in ids: - if train_samples[index][3] == "en": + if train_samples[index]["language"] == "en": en += 1 else: pt += 1 @@ -50,7 +50,7 @@ weighted_sampler = get_language_weighted_sampler(train_samples) ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) en, pt = 0, 0 for index in ids: - if 
train_samples[index][3] == "en": + if train_samples[index]["language"] == "en": en += 1 else: pt += 1 diff --git a/tests/inference_tests/test_synthesize.py b/tests/inference_tests/test_synthesize.py index 635506ab..42b77172 100644 --- a/tests/inference_tests/test_synthesize.py +++ b/tests/inference_tests/test_synthesize.py @@ -19,9 +19,9 @@ def test_synthesize(): f'--text "This is an example." --out_path "{output_path}"' ) - # multi-speaker model - run_cli("tts --model_name tts_models/en/vctk/sc-glow-tts --list_speaker_idxs") - run_cli( - f'tts --model_name tts_models/en/vctk/sc-glow-tts --speaker_idx "p304" ' - f'--text "This is an example." --out_path "{output_path}"' - ) + # multi-speaker SC-Glow model + # run_cli("tts --model_name tts_models/en/vctk/sc-glow-tts --list_speaker_idxs") + # run_cli( + # f'tts --model_name tts_models/en/vctk/sc-glow-tts --speaker_idx "p304" ' + # f'--text "This is an example." --out_path "{output_path}"' + # ) diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index 5972dc90..d643cb81 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -1,13 +1,12 @@ import os import unittest +from tests import get_tests_output_path from TTS.config import load_config from TTS.tts.models import setup_model from TTS.utils.io import save_checkpoint from TTS.utils.synthesizer import Synthesizer -from .. import get_tests_output_path - class SynthesizerTest(unittest.TestCase): # pylint: disable=R0201 diff --git a/tests/inputs/test_vocoder_multiband_melgan_config.json b/tests/inputs/test_vocoder_multiband_melgan_config.json index b8b192e4..82afc977 100644 --- a/tests/inputs/test_vocoder_multiband_melgan_config.json +++ b/tests/inputs/test_vocoder_multiband_melgan_config.json @@ -86,7 +86,7 @@ "mel_fmax": null }, - "target_loss": "avg_G_loss", // loss value to pick the best model to save after each epoch + "target_loss": "G_avg_loss", // loss value to pick the best model to save after each epoch // DISCRIMINATOR "discriminator_model": "melgan_multiscale_discriminator", diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py new file mode 100644 index 00000000..8f40656a --- /dev/null +++ b/tests/text_tests/test_characters.py @@ -0,0 +1,174 @@ +import unittest + +from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, Graphemes, IPAPhonemes + +# pylint: disable=protected-access + + +class BaseVocabularyTest(unittest.TestCase): + def setUp(self): + self.phonemes = IPAPhonemes() + self.base_vocab = BaseVocabulary( + vocab=self.phonemes._vocab, + pad=self.phonemes.pad, + blank=self.phonemes.blank, + bos=self.phonemes.bos, + eos=self.phonemes.eos, + ) + self.empty_vocab = BaseVocabulary({}) + + def test_pad_id(self): + self.assertEqual(self.empty_vocab.pad_id, 0) + self.assertEqual(self.base_vocab.pad_id, self.phonemes.pad_id) + + def test_blank_id(self): + self.assertEqual(self.empty_vocab.blank_id, 0) + self.assertEqual(self.base_vocab.blank_id, self.phonemes.blank_id) + + def test_vocab(self): + self.assertEqual(self.empty_vocab.vocab, {}) + self.assertEqual(self.base_vocab.vocab, self.phonemes._vocab) + + # def test_init_from_config(self): + # ... 
+ + def test_num_chars(self): + self.assertEqual(self.empty_vocab.num_chars, 0) + self.assertEqual(self.base_vocab.num_chars, self.phonemes.num_chars) + + def test_char_to_id(self): + try: + self.empty_vocab.char_to_id("a") + raise Exception("Should have raised KeyError") + except: + pass + for k in self.phonemes.vocab: + self.assertEqual(self.base_vocab.char_to_id(k), self.phonemes.char_to_id(k)) + + def test_id_to_char(self): + try: + self.empty_vocab.id_to_char(0) + raise Exception("Should have raised KeyError") + except: + pass + for k in self.phonemes.vocab: + v = self.phonemes.char_to_id(k) + self.assertEqual(self.base_vocab.id_to_char(v), self.phonemes.id_to_char(v)) + + +class BaseCharacterTest(unittest.TestCase): + def setUp(self): + self.characters_empty = BaseCharacters("", "", pad="", eos="", bos="", blank="", is_unique=True, is_sorted=True) + + def test_default_character_sets(self): # pylint: disable=no-self-use + """Test initiation of default character sets""" + _ = IPAPhonemes() + _ = Graphemes() + + def test_unique(self): + """Test if the unique option works""" + self.characters_empty.characters = "abcc" + self.characters_empty.punctuations = ".,;:!? " + self.characters_empty.pad = "[PAD]" + self.characters_empty.eos = "[EOS]" + self.characters_empty.bos = "[BOS]" + self.characters_empty.blank = "[BLANK]" + + self.assertEqual( + self.characters_empty.num_chars, + len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]), + ) + + def test_unique_sorted(self): + """Test if the unique and sorted option works""" + self.characters_empty.characters = "cba" + self.characters_empty.punctuations = ".,;:!? " + self.characters_empty.pad = "[PAD]" + self.characters_empty.eos = "[EOS]" + self.characters_empty.bos = "[BOS]" + self.characters_empty.blank = "[BLANK]" + + self.assertEqual( + self.characters_empty.num_chars, + len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]), + ) + + def test_setters_getters(self): + """Test the class setters behaves as expected""" + self.characters_empty.characters = "abc" + self.assertEqual(self.characters_empty._characters, "abc") + self.assertEqual(self.characters_empty.vocab, ["a", "b", "c"]) + + self.characters_empty.punctuations = ".,;:!? " + self.assertEqual(self.characters_empty._punctuations, ".,;:!? 
") + self.assertEqual(self.characters_empty.vocab, ["a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + + self.characters_empty.pad = "[PAD]" + self.assertEqual(self.characters_empty._pad, "[PAD]") + self.assertEqual(self.characters_empty.vocab, ["[PAD]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]) + + self.characters_empty.eos = "[EOS]" + self.assertEqual(self.characters_empty._eos, "[EOS]") + self.assertEqual( + self.characters_empty.vocab, ["[PAD]", "[EOS]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "] + ) + + self.characters_empty.bos = "[BOS]" + self.assertEqual(self.characters_empty._bos, "[BOS]") + self.assertEqual( + self.characters_empty.vocab, ["[PAD]", "[EOS]", "[BOS]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "] + ) + + self.characters_empty.blank = "[BLANK]" + self.assertEqual(self.characters_empty._blank, "[BLANK]") + self.assertEqual( + self.characters_empty.vocab, + ["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "], + ) + self.assertEqual( + self.characters_empty.num_chars, + len(["[PAD]", "[EOS]", "[BOS]", "[BLANK]", "a", "b", "c", ".", ",", ";", ":", "!", "?", " "]), + ) + + self.characters_empty.print_log() + + def test_char_lookup(self): + """Test char to ID and ID to char conversion""" + self.characters_empty.characters = "abc" + self.characters_empty.punctuations = ".,;:!? " + self.characters_empty.pad = "[PAD]" + self.characters_empty.eos = "[EOS]" + self.characters_empty.bos = "[BOS]" + self.characters_empty.blank = "[BLANK]" + + # char to ID + self.assertEqual(self.characters_empty.char_to_id("[PAD]"), 0) + self.assertEqual(self.characters_empty.char_to_id("[EOS]"), 1) + self.assertEqual(self.characters_empty.char_to_id("[BOS]"), 2) + self.assertEqual(self.characters_empty.char_to_id("[BLANK]"), 3) + self.assertEqual(self.characters_empty.char_to_id("a"), 4) + self.assertEqual(self.characters_empty.char_to_id("b"), 5) + self.assertEqual(self.characters_empty.char_to_id("c"), 6) + self.assertEqual(self.characters_empty.char_to_id("."), 7) + self.assertEqual(self.characters_empty.char_to_id(","), 8) + self.assertEqual(self.characters_empty.char_to_id(";"), 9) + self.assertEqual(self.characters_empty.char_to_id(":"), 10) + self.assertEqual(self.characters_empty.char_to_id("!"), 11) + self.assertEqual(self.characters_empty.char_to_id("?"), 12) + self.assertEqual(self.characters_empty.char_to_id(" "), 13) + + # ID to char + self.assertEqual(self.characters_empty.id_to_char(0), "[PAD]") + self.assertEqual(self.characters_empty.id_to_char(1), "[EOS]") + self.assertEqual(self.characters_empty.id_to_char(2), "[BOS]") + self.assertEqual(self.characters_empty.id_to_char(3), "[BLANK]") + self.assertEqual(self.characters_empty.id_to_char(4), "a") + self.assertEqual(self.characters_empty.id_to_char(5), "b") + self.assertEqual(self.characters_empty.id_to_char(6), "c") + self.assertEqual(self.characters_empty.id_to_char(7), ".") + self.assertEqual(self.characters_empty.id_to_char(8), ",") + self.assertEqual(self.characters_empty.id_to_char(9), ";") + self.assertEqual(self.characters_empty.id_to_char(10), ":") + self.assertEqual(self.characters_empty.id_to_char(11), "!") + self.assertEqual(self.characters_empty.id_to_char(12), "?") + self.assertEqual(self.characters_empty.id_to_char(13), " ") diff --git a/tests/text_tests/test_phonemizer.py b/tests/text_tests/test_phonemizer.py new file mode 100644 index 00000000..9b619f6e --- /dev/null +++ b/tests/text_tests/test_phonemizer.py @@ -0,0 +1,208 @@ +import unittest + +from 
TTS.tts.utils.text.phonemizers import ESpeak, Gruut, JA_JP_Phonemizer, ZH_CN_Phonemizer + +EXAMPLE_TEXTs = [ + "Recent research at Harvard has shown meditating", + "for as little as 8 weeks can actually increase, the grey matter", + "in the parts of the brain responsible", + "for emotional regulation and learning!", +] + + +EXPECTED_ESPEAK_PHONEMES = [ + "ɹ|ˈiː|s|ə|n|t ɹ|ɪ|s|ˈɜː|tʃ æ|t h|ˈɑːɹ|v|ɚ|d h|ɐ|z ʃ|ˈoʊ|n m|ˈɛ|d|ɪ|t|ˌeɪ|ɾ|ɪ|ŋ", + "f|ɔː|ɹ æ|z l|ˈɪ|ɾ|əl æ|z ˈeɪ|t w|ˈiː|k|s k|æ|n ˈæ|k|tʃ|uː|əl|i| ˈɪ|n|k|ɹ|iː|s, ð|ə ɡ|ɹ|ˈeɪ m|ˈæ|ɾ|ɚ", + "ɪ|n|ð|ə p|ˈɑːɹ|t|s ʌ|v|ð|ə b|ɹ|ˈeɪ|n ɹ|ɪ|s|p|ˈɑː|n|s|ə|b|əl", + "f|ɔː|ɹ ɪ|m|ˈoʊ|ʃ|ə|n|əl ɹ|ˌɛ|ɡ|j|uː|l|ˈeɪ|ʃ|ə|n|| æ|n|d l|ˈɜː|n|ɪ|ŋ!", +] + + +EXPECTED_ESPEAKNG_PHONEMES = [ + "ɹ|ˈiː|s|ə|n|t ɹ|ᵻ|s|ˈɜː|tʃ æ|t h|ˈɑːɹ|v|ɚ|d h|ɐ|z ʃ|ˈoʊ|n m|ˈɛ|d|ᵻ|t|ˌeɪ|ɾ|ɪ|ŋ", + "f|ɔː|ɹ æ|z l|ˈɪ|ɾ|əl æ|z ˈeɪ|t w|ˈiː|k|s k|æ|n ˈæ|k|tʃ|uː|əl|i| ˈɪ|ŋ|k|ɹ|iː|s, ð|ə ɡ|ɹ|ˈeɪ m|ˈæ|ɾ|ɚ", + "ɪ|n|ð|ə p|ˈɑːɹ|t|s ʌ|v|ð|ə b|ɹ|ˈeɪ|n ɹ|ᵻ|s|p|ˈɑː|n|s|ᵻ|b|əl", + "f|ɔː|ɹ ɪ|m|ˈoʊ|ʃ|ə|n|əl ɹ|ˌɛ|ɡ|j|ʊ|l|ˈeɪ|ʃ|ə|n|| æ|n|d l|ˈɜː|n|ɪ|ŋ!", +] + + +class TestEspeakPhonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = ESpeak(language="en-us", backend="espeak") + + for text, ph in zip(EXAMPLE_TEXTs, EXPECTED_ESPEAK_PHONEMES): + phonemes = self.phonemizer.phonemize(text) + self.assertEqual(phonemes, ph) + + # multiple punctuations + text = "Be a voice, not an! echo?" + gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ?" + output = self.phonemizer.phonemize(text, separator="|") + output = output.replace("|", "") + self.assertEqual(output, gt) + + # not ending with punctuation + text = "Be a voice, not an! echo" + gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ" + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + # extra space after the sentence + text = "Be a voice, not an! echo. " + gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ." + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "espeak") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), dict) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) + + +class TestEspeakNgPhonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = ESpeak(language="en-us", backend="espeak-ng") + + for text, ph in zip(EXAMPLE_TEXTs, EXPECTED_ESPEAKNG_PHONEMES): + phonemes = self.phonemizer.phonemize(text) + self.assertEqual(phonemes, ph) + + # multiple punctuations + text = "Be a voice, not an! echo?" + gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ?" + output = self.phonemizer.phonemize(text, separator="|") + output = output.replace("|", "") + self.assertEqual(output, gt) + + # not ending with punctuation + text = "Be a voice, not an! echo" + gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ" + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + # extra space after the sentence + text = "Be a voice, not an! echo. " + gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ." 
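The espeak and espeak-ng backends produce slightly different phoneme strings for the same sentences, which is why the two test classes pin down separate expected outputs. Used directly, the wrapper looks roughly like this (it needs an espeak or espeak-ng binary installed on the system):

```python
# Sketch of direct ESpeak phonemizer use; requires the espeak or espeak-ng binary.
from TTS.tts.utils.text.phonemizers import ESpeak

phonemizer = ESpeak(language="en-us", backend="espeak-ng")
print(phonemizer.is_available())  # False if the binary is missing
print(phonemizer.version())
print(phonemizer.phonemize("Be a voice, not an echo!", separator="|"))
```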
+ output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "espeak") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), dict) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) + + +class TestGruutPhonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = Gruut(language="en-us", use_espeak_phonemes=True, keep_stress=False) + self.EXPECTED_PHONEMES = [ + "ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ", + "f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ", + "ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l", + "f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!", + ] + + def test_phonemize(self): + for text, ph in zip(EXAMPLE_TEXTs, self.EXPECTED_PHONEMES): + phonemes = self.phonemizer.phonemize(text, separator="|") + self.assertEqual(phonemes, ph) + + # multiple punctuations + text = "Be a voice, not an! echo?" + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?" + output = self.phonemizer.phonemize(text, separator="|") + output = output.replace("|", "") + self.assertEqual(output, gt) + + # not ending with punctuation + text = "Be a voice, not an! echo" + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + # extra space after the sentence + text = "Be a voice, not an! echo. " + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ." + output = self.phonemizer.phonemize(text, separator="") + self.assertEqual(output, gt) + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "gruut") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), list) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) + + +class TestJA_JPPhonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = JA_JP_Phonemizer() + self._TEST_CASES = """ + どちらに行きますか?/dochiraniikimasuka? + 今日は温泉に、行きます。/kyo:waoNseNni,ikimasu. + 「A」から「Z」までです。/e:karazeqtomadedesu. + そうですね!/so:desune! + クジラは哺乳類です。/kujirawahonyu:ruidesu. + ヴィディオを見ます。/bidioomimasu. 
+ 今日は8月22日です/kyo:wahachigatsuniju:ninichidesu + xyzとαβγ/eqkusuwaizeqtotoarufabe:tagaNma + 値段は$12.34です/nedaNwaju:niteNsaNyoNdorudesu + """ + + def test_phonemize(self): + for line in self._TEST_CASES.strip().split("\n"): + text, phone = line.split("/") + self.assertEqual(self.phonemizer.phonemize(text, separator=""), phone) + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "ja_jp_phonemizer") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), dict) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) + + +class TestZH_CN_Phonemizer(unittest.TestCase): + def setUp(self): + self.phonemizer = ZH_CN_Phonemizer() + self._TEST_CASES = "" + + def test_phonemize(self): + # TODO: implement ZH phonemizer tests + pass + + def test_name(self): + self.assertEqual(self.phonemizer.name(), "zh_cn_phonemizer") + + def test_get_supported_languages(self): + self.assertIsInstance(self.phonemizer.supported_languages(), dict) + + def test_get_version(self): + self.assertIsInstance(self.phonemizer.version(), str) + + def test_is_available(self): + self.assertTrue(self.phonemizer.is_available()) diff --git a/tests/text_tests/test_punctuation.py b/tests/text_tests/test_punctuation.py new file mode 100644 index 00000000..141c10e4 --- /dev/null +++ b/tests/text_tests/test_punctuation.py @@ -0,0 +1,33 @@ +import unittest + +from TTS.tts.utils.text.punctuation import _DEF_PUNCS, Punctuation + + +class PunctuationTest(unittest.TestCase): + def setUp(self): + self.punctuation = Punctuation() + self.test_texts = [ + ("This, is my text ... to be striped !! from text?", "This is my text to be striped from text"), + ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"), + ("This, is my text ... 
to be striped from text?", "This is my text to be striped from text"), + ("This, is my text to be striped from text", "This is my text to be striped from text"), + ] + + def test_get_set_puncs(self): + self.punctuation.puncs = "-=" + self.assertEqual(self.punctuation.puncs, "-=") + + self.punctuation.puncs = _DEF_PUNCS + self.assertEqual(self.punctuation.puncs, _DEF_PUNCS) + + def test_strip_punc(self): + for text, gt in self.test_texts: + text_striped = self.punctuation.strip(text) + self.assertEqual(text_striped, gt) + + def test_strip_restore(self): + for text, gt in self.test_texts: + text_striped, puncs_map = self.punctuation.strip_to_restore(text) + text_restored = self.punctuation.restore(text_striped, puncs_map) + self.assertEqual(" ".join(text_striped), gt) + self.assertEqual(text_restored[0], text) diff --git a/tests/text_tests/test_symbols.py b/tests/text_tests/test_symbols.py deleted file mode 100644 index 49b25986..00000000 --- a/tests/text_tests/test_symbols.py +++ /dev/null @@ -1,8 +0,0 @@ -import unittest - -from TTS.tts.utils.text import phonemes - - -class SymbolsTest(unittest.TestCase): - def test_uniqueness(self): # pylint: disable=no-self-use - assert sorted(phonemes) == sorted(list(set(phonemes))), " {} vs {} ".format(len(phonemes), len(set(phonemes))) diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py new file mode 100644 index 00000000..908952ea --- /dev/null +++ b/tests/text_tests/test_tokenizer.py @@ -0,0 +1,94 @@ +import unittest +from dataclasses import dataclass + +from coqpit import Coqpit + +from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes, _blank, _bos, _eos, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.phonemizers import ESpeak +from TTS.tts.utils.text.tokenizer import TTSTokenizer + + +class TestTTSTokenizer(unittest.TestCase): + def setUp(self): + self.tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) + + self.ph = ESpeak("tr", backend="espeak") + self.tokenizer_ph = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph) + + def test_encode_decode_graphemes(self): + text = "This is, a test." + ids = self.tokenizer.encode(text) + test_hat = self.tokenizer.decode(ids) + self.assertEqual(text, test_hat) + self.assertEqual(len(ids), len(text)) + + def test_text_to_ids_phonemes(self): + # TODO: note sure how to extend to cover all the languages and phonemizer. + text = "Bu bir Örnek." + text_ph = self.ph.phonemize(text, separator="") + ids = self.tokenizer_ph.text_to_ids(text) + test_hat = self.tokenizer_ph.ids_to_text(ids) + self.assertEqual(text_ph, test_hat) + + def test_text_to_ids_phonemes_with_eos_bos(self): + text = "Bu bir Örnek." + self.tokenizer_ph.use_eos_bos = True + text_ph = IPAPhonemes().bos + self.ph.phonemize(text, separator="") + IPAPhonemes().eos + ids = self.tokenizer_ph.text_to_ids(text) + test_hat = self.tokenizer_ph.ids_to_text(ids) + self.assertEqual(text_ph, test_hat) + + def test_text_to_ids_phonemes_with_eos_bos_and_blank(self): + text = "Bu bir Örnek." + self.tokenizer_ph.use_eos_bos = True + self.tokenizer_ph.add_blank = True + text_ph = "bʊ bɪr œrnˈɛc." 
+ ids = self.tokenizer_ph.text_to_ids(text) + text_hat = self.tokenizer_ph.ids_to_text(ids) + self.assertEqual(text_ph, text_hat) + + def test_print_logs(self): + self.tokenizer.print_logs() + self.tokenizer_ph.print_logs() + + def test_not_found_characters(self): + self.ph = ESpeak("en-us") + tokenizer_local = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=self.ph) + self.assertEqual(len(self.tokenizer.not_found_characters), 0) + text = "Yolk of one egg beaten light" + ids = tokenizer_local.text_to_ids(text) + text_hat = tokenizer_local.ids_to_text(ids) + self.assertEqual(tokenizer_local.not_found_characters, ["̩"]) + self.assertEqual(text_hat, "jˈoʊk ʌv wˈʌn ˈɛɡ bˈiːʔn lˈaɪt") + + def test_init_from_config(self): + @dataclass + class Characters(Coqpit): + characters_class: str = None + characters: str = _phonemes + punctuations: str = _punctuations + pad: str = _pad + eos: str = _eos + bos: str = _bos + blank: str = _blank + is_unique: bool = True + is_sorted: bool = True + + @dataclass + class TokenizerConfig(Coqpit): + enable_eos_bos_chars: bool = True + use_phonemes: bool = True + add_blank: bool = False + characters: str = Characters() + phonemizer: str = "espeak" + phoneme_language: str = "tr" + text_cleaner: str = "phoneme_cleaners" + characters = Characters() + + tokenizer_ph, _ = TTSTokenizer.init_from_config(TokenizerConfig()) + tokenizer_ph.phonemizer.backend = "espeak" + text = "Bu bir Örnek." + text_ph = "" + self.ph.phonemize(text, separator="") + "" + ids = tokenizer_ph.text_to_ids(text) + test_hat = tokenizer_ph.ids_to_text(ids) + self.assertEqual(text_ph, test_hat) diff --git a/tests/tts_tests/test_align_tts_train.py b/tests/tts_tests/test_align_tts_train.py index f5d60d7c..85dfbbcb 100644 --- a/tests/tts_tests/test_align_tts_train.py +++ b/tests/tts_tests/test_align_tts_train.py @@ -2,6 +2,8 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.align_tts_config import AlignTTSConfig @@ -47,6 +49,14 @@ run_cli(command_train) # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + # restore the model and continue training for one more epoch command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " run_cli(command_train) diff --git a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py new file mode 100644 index 00000000..37faf449 --- /dev/null +++ b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py @@ -0,0 +1,83 @@ +import glob +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.config.shared_configs import BaseAudioConfig +from TTS.tts.configs.fast_pitch_config import FastPitchConfig + +config_path = os.path.join(get_tests_output_path(), "fast_pitch_speaker_emb_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = FastPitchConfig( + audio=audio_config, + batch_size=8, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + f0_cache_path="tests/data/ljspeech/f0_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + use_speaker_embedding=True, + test_sentences=[ + "Be a voice, not an echo.", + ], +) +config.audio.do_trim_silence = True +config.use_speaker_embedding = True +config.model_args.use_speaker_embedding = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = os.path.join(continue_path, "speakers.json") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_fast_pitch_train.py b/tests/tts_tests/test_fast_pitch_train.py index 71ba8b25..d2d78af4 100644 --- a/tests/tts_tests/test_fast_pitch_train.py +++ b/tests/tts_tests/test_fast_pitch_train.py @@ -2,11 +2,13 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig -config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json") +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") audio_config = BaseAudioConfig( @@ -41,8 +43,11 @@ config = FastPitchConfig( test_sentences=[ "Be a voice, not an echo.", ], + use_speaker_embedding=False, ) config.audio.do_trim_silence = True +config.use_speaker_embedding = False +config.model_args.use_speaker_embedding = False config.audio.trim_db = 60 config.save_json(config_path) @@ -57,11 +62,20 @@ command_train = ( "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " "--coqpit.test_delay_epochs 0" ) + run_cli(command_train) # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + # restore the model and continue training for one more epoch command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " run_cli(command_train) diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index 82d0ec3b..2783e4bd 100644 --- a/tests/tts_tests/test_glow_tts.py +++ b/tests/tts_tests/test_glow_tts.py @@ -4,11 +4,13 @@ import unittest import torch from torch import optim +from trainer.logging.tensorboard_logger import TensorboardLogger -from tests import get_tests_input_path +from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.models.glow_tts import GlowTTS +from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor # pylint: disable=unused-variable @@ -21,6 +23,7 @@ c = GlowTTSConfig() ap = AudioProcessor(**c.audio) WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") +BATCH_SIZE = 3 def count_parameters(model): @@ -28,36 +31,247 @@ def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) -class GlowTTSTrainTest(unittest.TestCase): +class TestGlowTTS(unittest.TestCase): @staticmethod - def test_train_step(): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8,)).long().to(device) + def _create_inputs(batch_size=8): + input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) - speaker_ids = torch.randint(0, 5, (8,)).long().to(device) + mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) + speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device) + return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + @staticmethod + def _check_parameter_changes(model, model_ref): + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( + count, param.shape, param, param_ref + ) + count += 1 + + def test_init_multispeaker(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config) + # speaker embedding with default speaker_embedding_dim + config.use_speaker_embedding = True + config.num_speakers = 5 + config.d_vector_dim = None + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, model.hidden_channels_enc) + # use external speaker embeddings with speaker_embedding_dim = 301 + config = GlowTTSConfig(num_chars=32) + config.use_d_vector_file = True + config.d_vector_dim = 301 + model = GlowTTS(config) + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, 301) + # use speaker embedddings by the provided speaker_manager + config = GlowTTSConfig(num_chars=32) + config.use_speaker_embedding = True + config.speakers_file = os.path.join(get_tests_data_path(), "ljspeech", "speakers.json") + speaker_manager = SpeakerManager.init_from_config(config) + model = GlowTTS(config) + model.speaker_manager = speaker_manager + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, model.hidden_channels_enc) + self.assertEqual(model.num_speakers, speaker_manager.num_speakers) + # use external speaker embeddings by the provided speaker_manager + config = GlowTTSConfig(num_chars=32) + config.use_d_vector_file = True + config.d_vector_dim = 256 + config.d_vector_file = os.path.join(get_tests_data_path(), "dummy_speakers.json") + speaker_manager = SpeakerManager.init_from_config(config) + model = GlowTTS(config) + model.speaker_manager = speaker_manager + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, speaker_manager.d_vector_dim) + self.assertEqual(model.num_speakers, speaker_manager.num_speakers) + + def test_unlock_act_norm_layers(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.unlock_act_norm_layers() + for f in model.decoder.flows: + if getattr(f, "set_ddi", False): + self.assertFalse(f.initialized) + + def test_lock_act_norm_layers(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.lock_act_norm_layers() + for f in model.decoder.flows: + if getattr(f, "set_ddi", False): + self.assertTrue(f.initialized) + + def _test_forward(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + # create model + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.train() + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + # inference encoder and decoder with MAS + y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths) + self.assertEqual(y["z"].shape, mel_spec.shape) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) + self.assertEqual(y["y_mean"].shape, mel_spec.shape) + self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) + self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + + def test_forward(self): + self._test_forward(1) + self._test_forward(3) + + def _test_forward_with_d_vector(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + d_vector = torch.rand(batch_size, 256).to(device) + # create model + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + 
d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.train() + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + # inference encoder and decoder with MAS + y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector}) + self.assertEqual(y["z"].shape, mel_spec.shape) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) + self.assertEqual(y["y_mean"].shape, mel_spec.shape) + self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) + self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + + def test_forward_with_d_vector(self): + self._test_forward_with_d_vector(1) + self._test_forward_with_d_vector(3) + + def _test_forward_with_speaker_id(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + speaker_ids = torch.randint(0, 24, (batch_size,)).long().to(device) + # create model + config = GlowTTSConfig( + num_chars=32, + use_speaker_embedding=True, + num_speakers=24, + ) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.train() + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + # inference encoder and decoder with MAS + y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids}) + self.assertEqual(y["z"].shape, mel_spec.shape) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) + self.assertEqual(y["y_mean"].shape, mel_spec.shape) + self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) + self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + + def test_forward_with_speaker_id(self): + self._test_forward_with_speaker_id(1) + self._test_forward_with_speaker_id(3) + + def _assert_inference_outputs(self, outputs, input_dummy, mel_spec): + output_shape = outputs["model_outputs"].shape + self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2]) + self.assertEqual(outputs["logdet"], None) + self.assertEqual(outputs["y_mean"].shape, output_shape) + self.assertEqual(outputs["y_log_scale"].shape, output_shape) + self.assertEqual(outputs["alignments"].shape, output_shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(outputs["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(outputs["total_durations_log"].shape, input_dummy.shape + (1,)) + + def _test_inference(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.eval() + outputs = model.inference(input_dummy, {"x_lengths": input_lengths}) + self._assert_inference_outputs(outputs, input_dummy, mel_spec) + + def test_inference(self): + self._test_inference(1) + self._test_inference(3) + + def _test_inference_with_d_vector(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + d_vector = torch.rand(batch_size, 256).to(device) + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + 
d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.eval() + outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) + self._assert_inference_outputs(outputs, input_dummy, mel_spec) + + def test_inference_with_d_vector(self): + self._test_inference_with_d_vector(1) + self._test_inference_with_d_vector(3) + + def _test_inference_with_speaker_ids(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + speaker_ids = torch.randint(0, 24, (batch_size,)).long().to(device) + # create model + config = GlowTTSConfig( + num_chars=32, + use_speaker_embedding=True, + num_speakers=24, + ) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) + self._assert_inference_outputs(outputs, input_dummy, mel_spec) + + def test_inference_with_speaker_ids(self): + self._test_inference_with_speaker_ids(1) + self._test_inference_with_speaker_ids(3) + + def _test_inference_with_MAS(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + # create model + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.eval() + # inference encoder and decoder with MAS + y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths) + y2 = model.decoder_inference(mel_spec, mel_lengths) + assert ( + y2["model_outputs"].shape == y["model_outputs"].shape + ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format( + y["model_outputs"].shape, y2["model_outputs"].shape + ) + + def test_inference_with_MAS(self): + self._test_inference_with_MAS(1) + self._test_inference_with_MAS(3) + + def test_train_step(self): + batch_size = BATCH_SIZE + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) criterion = GlowTTSLoss() - # model to train config = GlowTTSConfig(num_chars=32) model = GlowTTS(config).to(device) - # reference model to compare model weights model_ref = GlowTTS(config).to(device) - model.train() print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) - # pass the state to ref model model_ref.load_state_dict(copy.deepcopy(model.state_dict())) - count = 0 for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 - optimizer = optim.Adam(model.parameters(), lr=0.001) for _ in range(5): optimizer.zero_grad() @@ -75,40 +289,90 @@ class GlowTTSTrainTest(unittest.TestCase): loss = loss_dict["loss"] loss.backward() optimizer.step() - # check parameter changes - count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): - assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( - count, param.shape, param, param_ref - ) - count += 1 + self._check_parameter_changes(model, model_ref) - -class GlowTTSInferenceTest(unittest.TestCase): - @staticmethod - def test_inference(): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8,)).long().to(device) - input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) - speaker_ids = torch.randint(0, 5, (8,)).long().to(device) - - # create model + def test_train_eval_log(self): + batch_size = BATCH_SIZE + input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size) + batch = {} + batch["text_input"] = input_dummy + batch["text_lengths"] = input_lengths + batch["mel_lengths"] = mel_lengths + batch["mel_input"] = mel_spec + batch["d_vectors"] = None + batch["speaker_ids"] = None config = GlowTTSConfig(num_chars=32) - model = GlowTTS(config).to(device) - - model.eval() - print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) - - # inference encoder and decoder with MAS - y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths) - - y2 = model.decoder_inference(mel_spec, mel_lengths) - - assert ( - y2["model_outputs"].shape == y["model_outputs"].shape - ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format( - y["model_outputs"].shape, y2["model_outputs"].shape + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.run_data_dep_init = False + model.train() + logger = TensorboardLogger( + log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name="glow_tts_test_train_log" ) + criterion = model.get_criterion() + outputs, _ = model.train_step(batch, criterion) + model.train_log(batch, outputs, logger, None, 1) + model.eval_log(batch, outputs, logger, None, 1) + logger.finish() + + def test_test_run(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.run_data_dep_init = False + model.eval() + test_figures, test_audios = model.test_run(None) + self.assertTrue(test_figures is not None) + self.assertTrue(test_audios is not None) + + def test_load_checkpoint(self): + chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + chkp = {} + chkp["model"] = model.state_dict() + torch.save(chkp, chkp_path) + model.load_checkpoint(config, chkp_path) + self.assertTrue(model.training) + model.load_checkpoint(config, chkp_path, eval=True) + self.assertFalse(model.training) + + def test_get_criterion(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + criterion = model.get_criterion() + self.assertTrue(criterion is not None) + + def test_init_from_config(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + + config = GlowTTSConfig(num_chars=32, num_speakers=2) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 2) + self.assertTrue(not hasattr(model, "emb_g")) + + config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + 
self.assertTrue(model.num_speakers == 2) + self.assertTrue(hasattr(model, "emb_g")) + + config = GlowTTSConfig( + num_chars=32, + num_speakers=2, + use_speaker_embedding=True, + speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), + ) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 10) + self.assertTrue(hasattr(model, "emb_g")) + + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 1) + self.assertTrue(not hasattr(model, "emb_g")) + self.assertTrue(model.c_in_channels == config.d_vector_dim) diff --git a/tests/tts_tests/test_glow_tts_d-vectors_train.py b/tests/tts_tests/test_glow_tts_d-vectors_train.py new file mode 100644 index 00000000..14f9e4d2 --- /dev/null +++ b/tests/tts_tests/test_glow_tts_d-vectors_train.py @@ -0,0 +1,70 @@ +import glob +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.glow_tts_config import GlowTTSConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = GlowTTSConfig( + batch_size=2, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + "Be a voice, not an echo.", + ], + data_dep_init_steps=1.0, + use_speaker_embedding=False, + use_d_vector_file=True, + d_vector_file="tests/data/ljspeech/speakers.json", + d_vector_dim=256, +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = config.d_vector_file + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_glow_tts_speaker_emb_train.py b/tests/tts_tests/test_glow_tts_speaker_emb_train.py new file mode 100644 index 00000000..c327332e --- /dev/null +++ b/tests/tts_tests/test_glow_tts_speaker_emb_train.py @@ -0,0 +1,67 @@ +import glob +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.glow_tts_config import GlowTTSConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = GlowTTSConfig( + batch_size=2, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + "Be a voice, not an echo.", + ], + data_dep_init_steps=1.0, + use_speaker_embedding=True, +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = os.path.join(continue_path, "speakers.json") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests/test_glow_tts_train.py index e5901076..b0acf004 100644 --- a/tests/tts_tests/test_glow_tts_train.py +++ b/tests/tts_tests/test_glow_tts_train.py @@ -2,6 +2,8 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig @@ -16,7 +18,6 @@ config = GlowTTSConfig( num_eval_loader_workers=0, text_cleaner="english_cleaners", use_phonemes=True, - use_espeak_phonemes=True, phoneme_language="en-us", phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", run_eval=True, @@ -49,6 +50,14 @@ run_cli(command_train) # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + # restore the model and continue training for one more epoch command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " run_cli(command_train) diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py index 6a2f260d..23bb440a 100644 --- a/tests/tts_tests/test_helpers.py +++ b/tests/tts_tests/test_helpers.py @@ -1,6 +1,6 @@ import torch as T -from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask +from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask def average_over_durations_test(): # pylint: disable=no-self-use @@ -39,6 +39,34 @@ def segment_test(): for idx, start_indx in enumerate(segment_ids): assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum() + try: + segments = segment(x, segment_ids, segment_size=10) + raise Exception("Should have failed") + except: + pass + + segments = segment(x, segment_ids, segment_size=10, pad_short=True) + for idx, start_indx in enumerate(segment_ids): + assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum() + + +def rand_segments_test(): + x = T.rand(2, 3, 4) + x_lens = T.randint(3, 4, (2,)) + segments, seg_idxs = rand_segments(x, x_lens, segment_size=3) + assert segments.shape == (2, 3, 3) + assert all(seg_idxs >= 0), seg_idxs + try: + segments, _ = rand_segments(x, x_lens, segment_size=5) + raise Exception("Should have failed") + except: + pass + x_lens_back = x_lens.clone() + segments, seg_idxs = rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True) + assert segments.shape == (2, 3, 5) + assert all(seg_idxs >= 0), seg_idxs + assert all(x_lens_back == x_lens) + def generate_path_test(): durations = 
T.randint(1, 4, (10, 21)) diff --git a/tests/tts_tests/test_speedy_speech_train.py b/tests/tts_tests/test_speedy_speech_train.py index 7d7ecc7c..9a26d253 100644 --- a/tests/tts_tests/test_speedy_speech_train.py +++ b/tests/tts_tests/test_speedy_speech_train.py @@ -2,6 +2,8 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig @@ -47,6 +49,14 @@ run_cli(command_train) # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example for it.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + # restore the model and continue training for one more epoch command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " run_cli(command_train) diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py index c817badc..6b003f2c 100644 --- a/tests/tts_tests/test_tacotron2_d-vectors_train.py +++ b/tests/tts_tests/test_tacotron2_d-vectors_train.py @@ -2,6 +2,8 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config @@ -52,6 +54,16 @@ run_cli(command_train) # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = config.d_vector_file + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + # restore the model and continue training for one more epoch command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " run_cli(command_train) diff --git a/tests/tts_tests/test_tacotron2_speaker_emb_train.py b/tests/tts_tests/test_tacotron2_speaker_emb_train.py index 095016d8..b9f4de0b 100644 --- a/tests/tts_tests/test_tacotron2_speaker_emb_train.py +++ b/tests/tts_tests/test_tacotron2_speaker_emb_train.py @@ -2,6 +2,8 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config @@ -27,6 +29,7 @@ config = Tacotron2Config( "Be a voice, not an echo.", ], use_speaker_embedding=True, + num_speakers=4, max_decoder_steps=50, ) @@ -49,6 +52,16 @@ run_cli(command_train) # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = os.path.join(continue_path, "speakers.json") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + # restore the model and continue training for one more epoch command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " run_cli(command_train) diff --git a/tests/tts_tests/test_tacotron2_tf_model.py b/tests/tts_tests/test_tacotron2_tf_model.py deleted file mode 100644 index fb1efcde..00000000 --- a/tests/tts_tests/test_tacotron2_tf_model.py +++ /dev/null @@ -1,156 +0,0 @@ -import os -import unittest - -import numpy as np -import tensorflow as tf -import torch - -from TTS.tts.configs.tacotron2_config import Tacotron2Config -from TTS.tts.tf.models.tacotron2 import Tacotron2 -from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model - -tf.get_logger().setLevel("INFO") - - -# pylint: disable=unused-variable - -torch.manual_seed(1) -use_cuda = torch.cuda.is_available() -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - -c = Tacotron2Config() - - -class TacotronTFTrainTest(unittest.TestCase): - @staticmethod - def generate_dummy_inputs(): - chars_seq = torch.randint(0, 24, (8, 128)).long().to(device) - chars_seq_lengths = torch.randint(100, 128, (8,)).long().to(device) - chars_seq_lengths = torch.sort(chars_seq_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) - stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8,)).long().to(device) - - chars_seq = tf.convert_to_tensor(chars_seq.cpu().numpy()) - chars_seq_lengths = tf.convert_to_tensor(chars_seq_lengths.cpu().numpy()) - mel_spec = 
tf.convert_to_tensor(mel_spec.cpu().numpy()) - return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths, stop_targets, speaker_ids - - @unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.") - def test_train_step(self): - """test forward pass""" - ( - chars_seq, - chars_seq_lengths, - mel_spec, - mel_postnet_spec, - mel_lengths, - stop_targets, - speaker_ids, - ) = self.generate_dummy_inputs() - - for idx in mel_lengths: - stop_targets[:, int(idx.item()) :, 0] = 1.0 - - stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() - - model = Tacotron2(num_chars=24, r=c.r, num_speakers=5) - # training pass - output = model(chars_seq, chars_seq_lengths, mel_spec, training=True) - - # check model output shapes - assert np.all(output[0].shape == mel_spec.shape) - assert np.all(output[1].shape == mel_spec.shape) - assert output[2].shape[2] == chars_seq.shape[1] - assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r) - assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r) - - # inference pass - output = model(chars_seq, training=False) - - @unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.") - def test_forward_attention( - self, - ): - ( - chars_seq, - chars_seq_lengths, - mel_spec, - mel_postnet_spec, - mel_lengths, - stop_targets, - speaker_ids, - ) = self.generate_dummy_inputs() - - for idx in mel_lengths: - stop_targets[:, int(idx.item()) :, 0] = 1.0 - - stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() - - model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, forward_attn=True) - # training pass - output = model(chars_seq, chars_seq_lengths, mel_spec, training=True) - - # check model output shapes - assert np.all(output[0].shape == mel_spec.shape) - assert np.all(output[1].shape == mel_spec.shape) - assert output[2].shape[2] == chars_seq.shape[1] - assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r) - assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r) - - # inference pass - output = model(chars_seq, training=False) - - @unittest.skipIf(use_cuda, " [!] 
Skip Test: TfLite conversion does not work on GPU.") - def test_tflite_conversion( - self, - ): # pylint:disable=no-self-use - model = Tacotron2( - num_chars=24, - num_speakers=0, - r=3, - out_channels=80, - decoder_output_dim=80, - attn_type="original", - attn_win=False, - attn_norm="sigmoid", - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=0, - separate_stopnet=True, - bidirectional_decoder=False, - enable_tflite=True, - ) - model.build_inference() - convert_tacotron2_to_tflite(model, output_path="test_tacotron2.tflite", experimental_converter=True) - # init tflite model - tflite_model = load_tflite_model("test_tacotron2.tflite") - # fake input - inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) # pylint:disable=unexpected-keyword-arg - # run inference - # get input and output details - input_details = tflite_model.get_input_details() - output_details = tflite_model.get_output_details() - # reshape input tensor for the new input shape - tflite_model.resize_tensor_input( - input_details[0]["index"], inputs.shape - ) # pylint:disable=unexpected-keyword-arg - tflite_model.allocate_tensors() - detail = input_details[0] - input_shape = detail["shape"] - tflite_model.set_tensor(detail["index"], inputs) - # run the tflite_model - tflite_model.invoke() - # collect outputs - decoder_output = tflite_model.get_tensor(output_details[0]["index"]) - postnet_output = tflite_model.get_tensor(output_details[1]["index"]) - # remove tflite binary - os.remove("test_tacotron2.tflite") diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index 4f37ef89..8c30d9f9 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -2,6 +2,8 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config @@ -47,6 +49,14 @@ run_cli(command_train) # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' 
--config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + # restore the model and continue training for one more epoch command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " run_cli(command_train) diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py index 68071c66..40cd2d3d 100644 --- a/tests/tts_tests/test_tacotron_train.py +++ b/tests/tts_tests/test_tacotron_train.py @@ -2,6 +2,8 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron_config import TacotronConfig @@ -48,6 +50,14 @@ run_cli(command_train) # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + # restore the model and continue training for one more epoch command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " run_cli(command_train) diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 4274d947..384234e5 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -1,17 +1,20 @@ +import copy import os import unittest import torch +from trainer.logging.tensorboard_logger import TensorboardLogger -from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path +from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path from TTS.config import load_config from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.tts.configs.vits_config import VitsConfig -from TTS.tts.models.vits import Vits, VitsArgs +from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec from TTS.tts.utils.speakers import SpeakerManager LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") torch.manual_seed(1) @@ -21,6 +24,38 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # pylint: disable=no-self-use class TestVits(unittest.TestCase): + def test_load_audio(self): + wav, sr = load_audio(WAV_FILE) + self.assertEqual(wav.shape, (1, 41885)) + self.assertEqual(sr, 22050) + + spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False) + mel = wav_to_mel( + wav, + n_fft=1024, + num_mels=80, + sample_rate=sr, + hop_length=512, + win_length=1024, + fmin=0, + fmax=8000, + center=False, + ) + mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000) + + self.assertEqual((mel - mel2).abs().max(), 0) + self.assertEqual(spec.shape[0], mel.shape[0]) + self.assertEqual(spec.shape[2], mel.shape[2]) + + spec_db = amp_to_db(spec) + spec_amp = db_to_amp(spec_db) + + 
self.assertAlmostEqual((spec - spec_amp).abs().max(), 0, delta=1e-4) + + def test_dataset(self): + """TODO:""" + ... + def test_init_multispeaker(self): num_speakers = 10 args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True) @@ -100,35 +135,36 @@ class TestVits(unittest.TestCase): self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len)) self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len)) - def _init_inputs(self, config): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8,)).long().to(device) + def _create_inputs(self, config, batch_size=2): + input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) input_lengths[-1] = 128 - spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device) - spec_lengths = torch.randint(20, 30, (8,)).long().to(device) + spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device) + mel = torch.rand(batch_size, config.audio["num_mels"], 30).to(device) + spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) spec_lengths[-1] = spec.size(2) - waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device) - return input_dummy, input_lengths, spec, spec_lengths, waveform + waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device) + return input_dummy, input_lengths, mel, spec, spec_lengths, waveform - def _check_forward_outputs(self, config, output_dict, encoder_config=None): + def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2): self.assertEqual( output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] ) - self.assertEqual(output_dict["alignments"].shape, (8, 128, 30)) + self.assertEqual(output_dict["alignments"].shape, (batch_size, 128, 30)) self.assertEqual(output_dict["alignments"].max(), 1) self.assertEqual(output_dict["alignments"].min(), 0) - self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["z"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["z_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["m_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["logs_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["m_q"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["logs_q"].shape, (batch_size, config.model_args.hidden_channels, 30)) self.assertEqual( output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] ) if encoder_config: - self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"])) - self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"])) + 
self.assertEqual(output_dict["gt_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"])) + self.assertEqual(output_dict["syn_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"])) else: self.assertEqual(output_dict["gt_spk_emb"], None) self.assertEqual(output_dict["syn_spk_emb"], None) @@ -137,7 +173,7 @@ class TestVits(unittest.TestCase): num_speakers = 0 config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) config.model_args.spec_segment_size = 10 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config) model = Vits(config).to(device) output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform) self._check_forward_outputs(config, output_dict) @@ -148,7 +184,7 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) config.model_args.spec_segment_size = 10 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config) speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) model = Vits(config).to(device) @@ -157,16 +193,36 @@ class TestVits(unittest.TestCase): ) self._check_forward_outputs(config, output_dict) + def test_d_vector_forward(self): + batch_size = 2 + args = VitsArgs( + spec_segment_size=10, + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + config = VitsConfig(model_args=args) + model = Vits.init_from_config(config, verbose=False).to(device) + model.train() + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + d_vectors = torch.randn(batch_size, 256).to(device) + output_dict = model.forward( + input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors} + ) + self._check_forward_outputs(config, output_dict) + def test_multilingual_forward(self): num_speakers = 10 num_langs = 3 + batch_size = 2 args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10) config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) - speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) - lang_ids = torch.randint(0, num_langs, (8,)).long().to(device) + input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) model = Vits(config).to(device) output_dict = model.forward( @@ -182,6 +238,7 @@ class TestVits(unittest.TestCase): def test_secl_forward(self): num_speakers = 10 num_langs = 3 + batch_size = 2 speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG) speaker_encoder_config.model_params["use_torch_spec"] = True @@ -198,9 +255,9 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) config.audio.sample_rate = 16000 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) - speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) - lang_ids 
= torch.randint(0, num_langs, (8,)).long().to(device)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
         model = Vits(config, speaker_manager=speaker_manager).to(device)
         output_dict = model.forward(
@@ -213,28 +270,237 @@ class TestVits(unittest.TestCase):
         )
         self._check_forward_outputs(config, output_dict, speaker_encoder_config)
 
+    def _check_inference_outputs(self, config, outputs, input_dummy, batch_size=1):
+        feat_len = outputs["z"].shape[2]
+        self.assertEqual(outputs["model_outputs"].shape[:2], (batch_size, 1))  # output length is variable, so only check batch and channel dims
+        self.assertEqual(outputs["alignments"].shape, (batch_size, input_dummy.shape[1], feat_len))
+        self.assertEqual(outputs["z"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
+        self.assertEqual(outputs["z_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
+        self.assertEqual(outputs["m_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
+        self.assertEqual(outputs["logs_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
+
     def test_inference(self):
         num_speakers = 0
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
-        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
         model = Vits(config).to(device)
-        _ = model.inference(input_dummy)
+
+        batch_size = 1
+        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
+        outputs = model.inference(input_dummy)
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
+
+        batch_size = 2
+        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
+        outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
 
     def test_multispeaker_inference(self):
         num_speakers = 10
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
-        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
-        speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
         model = Vits(config).to(device)
-        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids})
+
+        batch_size = 1
+        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
+
+        batch_size = 2
+        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
 
     def test_multilingual_inference(self):
         num_speakers = 10
         num_langs = 3
         args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
+        model = Vits(config).to(device)
+
         input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
         speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (1,)).long().to(device)
-        model = Vits(config).to(device)
         _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})
+
+        batch_size = 1
+        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
+        outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
+
+        batch_size = 2
+        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
+        outputs = model.inference(
+            input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids, "language_ids": lang_ids}
+        )
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
+
+    def test_d_vector_inference(self):
+        args = VitsArgs(
+            spec_segment_size=10,
+            num_chars=32,
+            use_d_vector_file=True,
+            d_vector_dim=256,
+            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
+        )
+        config = VitsConfig(model_args=args)
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.eval()
+        # batch size = 1
+        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
+        d_vectors = torch.randn(1, 256).to(device)
+        outputs = model.inference(input_dummy, aux_input={"d_vectors": d_vectors})
+        self._check_inference_outputs(config, outputs, input_dummy)
+        # batch size = 2
+        input_dummy, input_lengths, *_ = self._create_inputs(config)
+        d_vectors = torch.randn(2, 256).to(device)
+        outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=2)
+
+    @staticmethod
+    def _check_parameter_changes(model, model_ref):
+        count = 0
+        for item1, item2 in zip(model.named_parameters(), model_ref.named_parameters()):
+            name = item1[0]
+            param = item1[1]
+            param_ref = item2[1]
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                name, param.shape, param, param_ref
+            )
+            count = count + 1
+
+    def _create_batch(self, config, batch_size):
+        input_dummy, input_lengths, mel, spec, mel_lengths, _ = self._create_inputs(config, batch_size)
+        batch = {}
+        batch["tokens"] = input_dummy
+        batch["token_lens"] = input_lengths
+        batch["spec_lens"] = mel_lengths
+        batch["mel_lens"] = mel_lengths
+        batch["spec"] = spec
+        batch["mel"] = mel
+        batch["waveform"] = torch.rand(batch_size, 1, config.audio["sample_rate"] * 10).to(device)
+        batch["d_vectors"] = None
+        batch["speaker_ids"] = None
+        batch["language_ids"] = None
+        return batch
+
+    def test_train_step(self):
+        # setup the model
+        with torch.autograd.set_detect_anomaly(True):
+
+            config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    outputs, loss_dict = model.train_step(batch, criterions, idx)
+                    self.assertFalse(not outputs)
+                    self.assertFalse(not loss_dict)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
+
+            # check parameter changes
+            self._check_parameter_changes(model, model_ref)
+
+    def test_train_eval_log(self):
+        batch_size = 2
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.run_data_dep_init = False
+        model.train()
+        batch = self._create_batch(config, batch_size)
+        logger = TensorboardLogger(
+            log_dir=os.path.join(get_tests_output_path(), "dummy_vits_logs"), model_name="vits_test_train_log"
+        )
+        criterion = model.get_criterion()
+        criterion = [criterion[0].to(device), criterion[1].to(device)]
+        outputs = [None] * 2
+        outputs[0], _ = model.train_step(batch, criterion, 0)
+        outputs[1], _ = model.train_step(batch, criterion, 1)
+        model.train_log(batch, outputs, logger, None, 1)
+
+        model.eval_log(batch, outputs, logger, None, 1)
+        logger.finish()
+
+    def test_test_run(self):
+        config = VitsConfig(model_args=VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.run_data_dep_init = False
+        model.eval()
+        test_figures, test_audios = model.test_run(None)
+        self.assertTrue(test_figures is not None)
+        self.assertTrue(test_audios is not None)
+
+    def test_load_checkpoint(self):
+        chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
+        config = VitsConfig(VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        chkp = {}
+        chkp["model"] = model.state_dict()
+        torch.save(chkp, chkp_path)
+        model.load_checkpoint(config, chkp_path)
+        self.assertTrue(model.training)
+        model.load_checkpoint(config, chkp_path, eval=True)
+        self.assertFalse(model.training)
+
+    def test_get_criterion(self):
+        config = VitsConfig(VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        criterion = model.get_criterion()
+        self.assertTrue(criterion is not None)
+
+    def test_init_from_config(self):
+        config = VitsConfig(model_args=VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(not hasattr(model, "emb_g"))
+
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        self.assertEqual(model.num_speakers, 2)
+        self.assertTrue(hasattr(model, "emb_g"))
+
+        config = VitsConfig(
+            model_args=VitsArgs(
+                num_chars=32,
+                num_speakers=2,
+                use_speaker_embedding=True,
+                speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
+            )
+        )
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        self.assertEqual(model.num_speakers, 10)
+        self.assertTrue(hasattr(model, "emb_g"))
+
+        config = VitsConfig(
+            model_args=VitsArgs(
+                num_chars=32,
+                use_d_vector_file=True,
+                d_vector_dim=256,
+                d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
+            )
+        )
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(model.num_speakers == 1)
+        self.assertTrue(not hasattr(model, "emb_g"))
+        self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)
diff --git a/tests/tts_tests/test_vits_d-vectors_train.py b/tests/tts_tests/test_vits_d-vectors_train.py
index 213669f5..5fd9cbc1 100644
--- a/tests/tts_tests/test_vits_d-vectors_train.py
+++ b/tests/tts_tests/test_vits_d-vectors_train.py
@@ -16,7 +16,6 @@ config = VitsConfig(
     num_eval_loader_workers=0,
     text_cleaner="english_cleaners",
     use_phonemes=True,
-    use_espeak_phonemes=True,
     phoneme_language="en-us",
     phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
     run_eval=True,
diff --git a/tests/tts_tests/test_vits_multilingual_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py
similarity index 75%
rename from tests/tts_tests/test_vits_multilingual_train.py
rename to tests/tts_tests/test_vits_multilingual_speaker_emb_train.py
index 50cccca5..0c7672d7 100644
--- a/tests/tts_tests/test_vits_multilingual_train.py
+++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py
@@ -2,6 +2,8 @@ import glob
 import os
 import shutil
 
+from trainer import get_last_checkpoint
+
 from tests import get_device_id, get_tests_output_path, run_cli
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
@@ -33,7 +35,6 @@ config = VitsConfig(
     num_eval_loader_workers=0,
     text_cleaner="english_cleaners",
     use_phonemes=True,
-    use_espeak_phonemes=True,
     phoneme_language="en-us",
     phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
     run_eval=True,
@@ -82,6 +83,18 @@ run_cli(command_train)
 # Find latest folder
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
 
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech"
+language_id = "en"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+continue_languages_path = os.path.join(continue_path, "language_ids.json")
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --language_ids_file_path {continue_languages_path} --language_idx {language_id} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+run_cli(inference_command)
+
 # restore the model and continue training for one more epoch
 command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
index 1ca57d93..a8e2020e 100644
--- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
+++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
@@ -2,6 +2,8 @@ import glob
 import os
 import shutil
 
+from trainer import get_last_checkpoint
+
 from tests import get_device_id, get_tests_output_path, run_cli
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
@@ -11,7 +13,7 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
 dataset_config_en = BaseDatasetConfig(
-    name="ljspeech",
+    name="ljspeech_test",
     meta_file_train="metadata.csv",
     meta_file_val="metadata.csv",
     path="tests/data/ljspeech",
 )
@@ -19,7 +21,7 @@ dataset_config_pt = BaseDatasetConfig(
-    name="ljspeech",
+    name="ljspeech_test",
     meta_file_train="metadata.csv",
     meta_file_val="metadata.csv",
     path="tests/data/ljspeech",
@@ -31,7 +33,7 @@ config = VitsConfig(
     eval_batch_size=2,
     num_loader_workers=0,
     num_eval_loader_workers=0,
-    text_cleaner="english_cleaners",
+    text_cleaner="multilingual_cleaners",
     use_phonemes=False,
     phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
     run_eval=True,
@@ -85,6 +87,18 @@ run_cli(command_train)
 # Find latest folder
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
 
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+language_id = "en"
+continue_speakers_path = config.d_vector_file
+continue_languages_path = os.path.join(continue_path, "language_ids.json")
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --language_ids_file_path {continue_languages_path} --language_idx {language_id} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+run_cli(inference_command)
+
 # restore the model and continue training for one more epoch
 command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py
index 6cc1dabd..c928cee4 100644
--- a/tests/tts_tests/test_vits_speaker_emb_train.py
+++ b/tests/tts_tests/test_vits_speaker_emb_train.py
@@ -2,6 +2,8 @@ import glob
 import os
 import shutil
 
+from trainer import get_last_checkpoint
+
 from tests import get_device_id, get_tests_output_path, run_cli
 from TTS.tts.configs.vits_config import VitsConfig
@@ -16,7 +18,6 @@ config = VitsConfig(
     num_eval_loader_workers=0,
     text_cleaner="english_cleaners",
     use_phonemes=True,
-    use_espeak_phonemes=True,
     phoneme_language="en-us",
     phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
     run_eval=True,
@@ -25,7 +26,7 @@ config = VitsConfig(
     print_step=1,
     print_eval=True,
     test_sentences=[
-        ["Be a voice, not an echo.", "ljspeech"],
+        ["Be a voice, not an echo.", "ljspeech-1"],
     ],
 )
 # set audio config
@@ -45,7 +46,7 @@ config.save_json(config_path)
 command_train = (
     f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
     f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.name ljspeech_test "
     "--coqpit.datasets.0.meta_file_train metadata.csv "
     "--coqpit.datasets.0.meta_file_val metadata.csv "
     "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -57,6 +58,16 @@ run_cli(command_train)
 # Find latest folder
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
 
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+run_cli(inference_command)
+
 # restore the model and continue training for one more epoch
 command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py
index 607f7b29..003f99a8 100644
--- a/tests/tts_tests/test_vits_train.py
+++ b/tests/tts_tests/test_vits_train.py
@@ -2,6 +2,8 @@ import glob
 import os
 import shutil
 
+from trainer import get_last_checkpoint
+
 from tests import get_device_id, get_tests_output_path, run_cli
 from TTS.tts.configs.vits_config import VitsConfig
@@ -16,7 +18,6 @@ config = VitsConfig(
     num_eval_loader_workers=0,
     text_cleaner="english_cleaners",
    use_phonemes=True,
-    use_espeak_phonemes=True,
     phoneme_language="en-us",
     phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
     run_eval=True,
@@ -48,6 +49,14 @@ run_cli(command_train)
 # Find latest folder
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
 
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+run_cli(inference_command)
+
 # restore the model and continue training for one more epoch
 command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
diff --git a/tests/vocoder_tests/test_multiband_melgan_train.py b/tests/vocoder_tests/test_multiband_melgan_train.py
index c49107bd..80027607 100644
--- a/tests/vocoder_tests/test_multiband_melgan_train.py
+++ b/tests/vocoder_tests/test_multiband_melgan_train.py
@@ -20,6 +20,7 @@ config = MultibandMelganConfig(
     eval_split_size=1,
     print_step=1,
     print_eval=True,
+    steps_to_start_discriminator=1,
     data_path="tests/data/ljspeech",
     discriminator_model_params={"base_channels": 16, "max_channels": 64, "downsample_factors": [4, 4, 4]},
     output_path=output_path,
diff --git a/tests/vocoder_tests/test_vocoder_tf_melgan_generator.py b/tests/vocoder_tests/test_vocoder_tf_melgan_generator.py
deleted file mode 100644
index 225ceaf5..00000000
--- a/tests/vocoder_tests/test_vocoder_tf_melgan_generator.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import unittest
-
-import numpy as np
-import tensorflow as tf
-import torch
-
-from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
-
-use_cuda = torch.cuda.is_available()
-
-
-@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
-def test_melgan_generator():
-    hop_length = 256
-    model = MelganGenerator()
-    # pylint: disable=no-value-for-parameter
-    dummy_input = tf.random.uniform((4, 80, 64))
-    output = model(dummy_input, training=False)
-    assert np.all(output.shape == (4, 1, 64 * hop_length)), output.shape
diff --git a/tests/vocoder_tests/test_vocoder_tf_pqmf.py b/tests/vocoder_tests/test_vocoder_tf_pqmf.py
deleted file mode 100644
index 6acb20d9..00000000
--- a/tests/vocoder_tests/test_vocoder_tf_pqmf.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import os
-import unittest
-
-import soundfile as sf
-import tensorflow as tf
-import torch
-from librosa.core import load
-
-from tests import get_tests_input_path, get_tests_output_path, get_tests_path
-from TTS.vocoder.tf.layers.pqmf import PQMF
-
-TESTS_PATH = get_tests_path()
-WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
-use_cuda = torch.cuda.is_available()
-
-
-@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
-def test_pqmf():
-    w, sr = load(WAV_FILE)
-
-    layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
-    w, sr = load(WAV_FILE)
-    w2 = tf.convert_to_tensor(w[None, None, :])
-    b2 = layer.analysis(w2)
-    w2_ = layer.synthesis(b2)
-    w2_ = w2.numpy()
-
-    print(w2_.max())
-    print(w2_.min())
-    print(w2_.mean())
-    sf.write(os.path.join(get_tests_output_path(), "tf_pqmf_output.wav"), w2_.flatten(), sr)
diff --git a/tests/vocoder_tests/test_vocoder_wavernn.py b/tests/vocoder_tests/test_vocoder_wavernn.py
index d4a7b8dd..966ea3dd 100644
--- a/tests/vocoder_tests/test_vocoder_wavernn.py
+++ b/tests/vocoder_tests/test_vocoder_wavernn.py
@@ -46,6 +46,6 @@ def test_wavernn():
     config.model_args.mode = 4
     model = Wavernn(config)
     output = model(dummy_x, dummy_m)
-    assert np.all(output.shape == (2, 1280, 2 ** 4)), output.shape
+    assert np.all(output.shape == (2, 1280, 2**4)), output.shape
     output = model.inference(dummy_y, True, 5500, 550)
     assert np.all(output.shape == (256 * (y_size - 1),))