Merge pull request #1288 from coqui-ai/dev

v0.6.0
Commit 209ee40c88, authored by Eren Gölge on 2022-03-07 14:05:30 +01:00 and committed by GitHub.
208 changed files with 6477 additions and 6934 deletions

@@ -1,16 +0,0 @@
#!/bin/bash
yes | apt-get install sox
yes | apt-get install ffmpeg
yes | apt-get install tmux
yes | apt-get install zsh
sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"
pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl
sudo sh install.sh
# pip install pytorch==1.7.0+cu100
# python3 setup.py develop
# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
# python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360
# python3 distribute.py --config_path config.json
while true; do sleep 1000000; done

.github/workflows/data_tests.yml (new file)

@@ -0,0 +1,46 @@
name: data-tests
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
experimental: [false]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make data_tests

.github/workflows/inference_tests.yml (new file)

@@ -0,0 +1,46 @@
name: inference_tests
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
experimental: [false]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make inference_tests

.github/workflows/text_tests.yml (new file)

@@ -0,0 +1,48 @@
name: tts-tests
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
experimental: [false]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
sudo apt-get install espeak
sudo apt-get install espeak-ng
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make test_text

@@ -35,6 +35,8 @@ jobs:
         run: |
           sudo apt-get update
           sudo apt-get install -y --no-install-recommends git make gcc
+          sudo apt-get install espeak
+          sudo apt-get install espeak-ng
           make system-deps
       - name: Install/upgrade Python setup deps
         run: python3 -m pip install --upgrade pip setuptools wheel

@@ -35,6 +35,7 @@ jobs:
         run: |
           sudo apt-get update
           sudo apt-get install -y git make gcc
+          sudo apt-get install espeak espeak-ng
           make system-deps
       - name: Install/upgrade Python setup deps
         run: python3 -m pip install --upgrade pip setuptools wheel

.gitignore

@@ -165,3 +165,4 @@ internal/*
 *_phoneme.npy
 wandb
 depot/*
+coqui_recipes/*

@@ -168,7 +168,8 @@ disable=missing-docstring,
         exception-escape,
         comprehension-escape,
         duplicate-code,
-        not-callable
+        not-callable,
+        import-outside-toplevel

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option

@@ -26,6 +26,15 @@ test_aux: ## run aux tests.
 test_zoo: ## run zoo tests.
 	nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id

+inference_tests: ## run inference tests.
+	nosetests tests.inference_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.inference_tests --nologcapture --with-id
+
+data_tests: ## run data tests.
+	nosetests tests.data_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.data_tests --nologcapture --with-id
+
+test_text: ## run text tests.
+	nosetests tests.text_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.text_tests --nologcapture --with-id
+
 test_failed: ## only run tests failed the last time.
 	nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed

@@ -41,7 +50,6 @@ system-deps: ## install linux system deps
 dev-deps: ## install development deps
 	pip install -r requirements.dev.txt
-	pip install -r requirements.tf.txt

 doc-deps: ## install docs dependencies
 	pip install -r docs/requirements.txt
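
The three new Make targets mirror what the new GitHub Actions workflows invoke, so the same suites can be checked locally before pushing. A minimal sketch of that local flow; the editable install line is taken from the README instructions further down in this changeset and is an assumed prerequisite, not part of this Makefile:

```bash
# assumed prerequisite: install TTS with the dev extras (see the README diff below)
pip install -e .[all,dev,notebooks]

# run the suites that the new CI workflows call
make data_tests        # tests.data_tests
make inference_tests   # tests.inference_tests
make test_text         # tests.text_tests
```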

@@ -61,8 +61,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
 - Detailed training logs on the terminal and Tensorboard.
 - Support for Multi-speaker TTS.
 - Efficient, flexible, lightweight but feature complete `Trainer API`.
-- Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference.
-- Released and read-to-use models.
+- Released and ready-to-use models.
 - Tools to curate Text2Speech datasets under```dataset_analysis```.
 - Utilities to use and test your models.
 - Modular (but not too much) code base enabling easy implementation of new ideas.

@@ -113,17 +112,11 @@ If you are only interested in [synthesizing speech](https://tts.readthedocs.io/e
 pip install TTS
 ```

-By default, this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra.
-
-```bash
-pip install TTS[tf]
-```
-
 If you plan to code or train models, clone 🐸TTS and install it locally.

 ```bash
 git clone https://github.com/coqui-ai/TTS
-pip install -e .[all,dev,notebooks,tf]  # Select the relevant extras
+pip install -e .[all,dev,notebooks]  # Select the relevant extras
 ```

 If you are on Ubuntu (Debian), you can also run following commands for installation.

@@ -204,12 +197,10 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
 |- train*.py (train your target model.)
 |- distribute.py (train your TTS model using Multiple GPUs.)
 |- compute_statistics.py (compute dataset statistics for normalization.)
-|- convert*.py (convert target torch model to TF.)
 |- ...
 |- tts/ (text to speech models)
     |- layers/ (model layer definitions)
     |- models/ (model definitions)
-    |- tf/ (Tensorflow 2 utilities and model implementations)
     |- utils/ (model specific utilities.)
 |- speaker_encoder/ (Speaker Encoder models.)
     |- (same)

@@ -4,7 +4,7 @@
         "multi-dataset":{
             "your_tts":{
                 "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
                 "default_vocoder": null,
                 "commit": "e9a1953e",
                 "license": "CC BY-NC-ND 4.0",
@@ -33,7 +33,7 @@
             },
             "tacotron2-DDC_ph": {
                 "description": "Tacotron2 with Double Decoder Consistency with phonemes.",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--tacotron2-DDC_ph.zip",
                 "default_vocoder": "vocoder_models/en/ljspeech/univnet",
                 "commit": "3900448",
                 "author": "Eren Gölge @erogol",
@@ -71,7 +71,7 @@
             },
             "vits": {
                 "description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--vits.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--vits.zip",
                 "default_vocoder": null,
                 "commit": "3900448",
                 "author": "Eren Gölge @erogol",
@@ -89,18 +89,9 @@
                 }
             },
             "vctk": {
-                "sc-glow-tts": {
-                    "description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.",
-                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip",
-                    "default_vocoder": "vocoder_models/en/vctk/hifigan_v2",
-                    "commit": "b531fa69",
-                    "author": "Edresson Casanova",
-                    "license": "",
-                    "contact": ""
-                },
                 "vits": {
                     "description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.",
-                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--vctk--vits.zip",
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--vits.zip",
                     "default_vocoder": null,
                     "commit": "3900448",
                     "author": "Eren @erogol",
@@ -109,7 +100,7 @@
                 },
                 "fast_pitch":{
                     "description": "FastPitch model trained on VCTK dataseset.",
-                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--en--vctk--fast_pitch.zip",
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--fast_pitch.zip",
                     "default_vocoder": null,
                     "commit": "bdab788d",
                     "author": "Eren @erogol",
@@ -156,7 +147,7 @@
         "uk":{
             "mai": {
                 "glow-tts": {
-                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--uk--mailabs--glow-tts.zip",
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--uk--mai--glow-tts.zip",
                     "author":"@robinhad",
                     "commit": "bdab788d",
                     "license": "MIT",
@@ -168,7 +159,7 @@
         "zh-CN": {
             "baker": {
                 "tacotron2-DDC-GST": {
-                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
                     "commit": "unknown",
                     "author": "@kirianguiller",
                     "default_vocoder": null
@@ -206,6 +197,52 @@
                     "commit": "401fbd89"
                 }
             }
+        },
+        "tr":{
+            "common-voice": {
+                "glow-tts":{
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--tr--common-voice--glow-tts.zip",
+                    "default_vocoder": "vocoder_models/tr/common-voice/hifigan",
+                    "license": "MIT",
+                    "description": "Turkish GlowTTS model using an unknown speaker from the Common-Voice dataset.",
+                    "author": "Fatih Akademi",
+                    "commit": null
+                }
+            }
+        },
+        "it": {
+            "mai_female": {
+                "glow-tts":{
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--glow-tts.zip",
+                    "default_vocoder": null,
+                    "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                    "author": "@nicolalandro",
+                    "commit": null
+                },
+                "vits":{
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--vits.zip",
+                    "default_vocoder": null,
+                    "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                    "author": "@nicolalandro",
+                    "commit": null
+                }
+            },
+            "mai_male": {
+                "glow-tts":{
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--glow-tts.zip",
+                    "default_vocoder": null,
+                    "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                    "author": "@nicolalandro",
+                    "commit": null
+                },
+                "vits":{
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--vits.zip",
+                    "default_vocoder": null,
+                    "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                    "author": "@nicolalandro",
+                    "commit": null
+                }
+            }
         }
     },
     "vocoder_models": {
@@ -324,6 +361,17 @@
                 "contact": ""
             }
         }
+        },
+        "tr":{
+            "common-voice": {
+                "hifigan":{
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/vocoder_models--tr--common-voice--hifigan.zip",
+                    "description": "HifiGAN model using an unknown speaker from the Common-Voice dataset.",
+                    "author": "Fatih Akademi",
+                    "license": "MIT",
+                    "commit": null
+                }
+            }
         }
     }
 }
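
Every release URL now points at the `v0.6.0_models` bucket, and new Turkish and Italian models are registered. A hedged example of pulling one of the new entries with the `tts` CLI; the `<type>/<lang>/<dataset>/<model>` naming is inferred from the JSON keys above, and the CLI flags follow the standard `tts` command rather than anything introduced in this diff:

```bash
# model and vocoder names assembled from the new "tr"/"common-voice" entries above
tts --model_name "tts_models/tr/common-voice/glow-tts" \
    --vocoder_name "vocoder_models/tr/common-voice/hifigan" \
    --text "Merhaba dünya." \
    --out_path output.wav
```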

@@ -1 +1 @@
-0.5.0
+0.6.0

@@ -11,7 +11,7 @@ from tqdm import tqdm
 from TTS.config import load_config
 from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.models import setup_model
-from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
+from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_checkpoint

@@ -29,6 +29,9 @@ parser.add_argument(
     help="Path to dataset config file.",
 )
 parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
+parser.add_argument(
+    "--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None
+)
 parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
 parser.add_argument("--eval", type=bool, help="compute eval.", default=True)

@@ -40,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli
 wav_files = meta_data_train + meta_data_eval

 speaker_manager = SpeakerManager(
-    encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
+    encoder_model_path=args.model_path,
+    encoder_config_path=args.config_path,
+    d_vectors_file_path=args.old_file,
+    use_cuda=args.use_cuda,
 )

 # compute speaker embeddings
@@ -52,11 +58,15 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     else:
         speaker_name = None

-    # extract the embedding
-    embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
+    wav_file_name = os.path.basename(wav_file)
+    if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
+        # get the embedding from the old file
+        embedd = speaker_manager.get_d_vector_by_clip(wav_file_name)
+    else:
+        # extract the embedding
+        embedd = speaker_manager.compute_d_vector_from_clip(wav_file)

     # create speaker_mapping if target dataset is defined
-    wav_file_name = os.path.basename(wav_file)
     speaker_mapping[wav_file_name] = {}
     speaker_mapping[wav_file_name]["name"] = speaker_name
     speaker_mapping[wav_file_name]["embedding"] = embedd

@@ -51,7 +51,7 @@ def main():
     N = 0
     for item in tqdm(dataset_items):
         # compute features
-        wav = ap.load_wav(item if isinstance(item, str) else item[1])
+        wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
         linear = ap.spectrogram(wav)
         mel = ap.melspectrogram(wav)

@@ -59,13 +59,13 @@ def main():
         N += mel.shape[1]
         mel_sum += mel.sum(1)
         linear_sum += linear.sum(1)
-        mel_square_sum += (mel ** 2).sum(axis=1)
-        linear_square_sum += (linear ** 2).sum(axis=1)
+        mel_square_sum += (mel**2).sum(axis=1)
+        linear_square_sum += (linear**2).sum(axis=1)

     mel_mean = mel_sum / N
-    mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2)
+    mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2)
     linear_mean = linear_sum / N
-    linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2)
+    linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2)

     output_file_path = args.out_path
     stats = {}

@@ -1,25 +0,0 @@
# Convert Tensorflow Tacotron2 model to TF-Lite binary
import argparse
from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.generic_utils import setup_generator
from TTS.vocoder.tf.utils.io import load_checkpoint
from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite
parser = argparse.ArgumentParser()
parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
args = parser.parse_args()
# Set constants
CONFIG = load_config(args.config_path)
# load the model
model = setup_generator(CONFIG)
model.build_inference()
model = load_checkpoint(model, args.tf_model)
# create tflite model
tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path)

@@ -1,105 +0,0 @@
import argparse
import os
from difflib import SequenceMatcher
import numpy as np
import tensorflow as tf
import torch
from TTS.utils.io import load_config, load_fsspec
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf,
convert_tf_name,
transfer_weights_torch_to_tf,
)
from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import save_checkpoint
from TTS.vocoder.utils.generic_utils import setup_generator
# prevent GPU use
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# define args
parser = argparse.ArgumentParser()
parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
args = parser.parse_args()
# load model config
config_path = args.config_path
c = load_config(config_path)
num_speakers = 0
# init torch model
model = setup_generator(c)
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()
state_dict = model.state_dict()
# init tf model
model_tf = setup_tf_generator(c)
common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
# get tf_model graph by passing an input
# B x D x T
dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
mel_pred = model_tf(dummy_input, training=False)
# get tf variables
tf_vars = model_tf.weights
# match variable names with fuzzy logic
torch_var_names = list(state_dict.keys())
tf_var_names = [we.name for we in model_tf.weights]
var_map = []
for tf_name in tf_var_names:
    # skip re-mapped layer names
    if tf_name in [name[0] for name in var_map]:
        continue
    tf_name_edited = convert_tf_name(tf_name)
    ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
    max_idx = np.argmax(ratios)
    matching_name = torch_var_names[max_idx]
    del torch_var_names[max_idx]
    var_map.append((tf_name, matching_name))
# pass weights
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
# Compare TF and TORCH models
# check embedding outputs
model.eval()
dummy_input_torch = torch.ones((1, 80, 10))
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1])
dummy_input_tf = tf.expand_dims(dummy_input_tf, 2)
out_torch = model.layers[0](dummy_input_torch)
out_tf = model_tf.model_layers[0](dummy_input_tf)
out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
assert compare_torch_tf(out_torch, out_tf_) < 1e-5
for i in range(1, len(model.layers)):
    print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}")
    out_torch = model.layers[i](out_torch)
    out_tf = model_tf.model_layers[i](out_tf)
    out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
    diff = compare_torch_tf(out_torch, out_tf_)
    assert diff < 1e-5, diff
torch.manual_seed(0)
dummy_input_torch = torch.rand((1, 80, 100))
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
model.inference_padding = 0
model_tf.inference_padding = 0
output_torch = model.inference(dummy_input_torch)
output_tf = model_tf(dummy_input_tf, training=False)
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf)
# save tf model
save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path)
print(" > Model conversion is successfully completed :).")

@@ -1,30 +0,0 @@
# Convert Tensorflow Tacotron2 model to TF-Lite binary
import argparse
from TTS.tts.tf.utils.generic_utils import setup_model
from TTS.tts.tf.utils.io import load_checkpoint
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config
parser = argparse.ArgumentParser()
parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
args = parser.parse_args()
# Set constants
CONFIG = load_config(args.config_path)
# load the model
c = CONFIG
num_speakers = 0
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
model.build_inference()
model = load_checkpoint(model, args.tf_model)
model.decoder.set_max_decoder_steps(1000)
# create tflite model
tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)

@@ -1,187 +0,0 @@
import argparse
import os
import sys
from difflib import SequenceMatcher
from pprint import pprint
import numpy as np
import tensorflow as tf
import torch
from TTS.tts.models import setup_model
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config, load_fsspec
sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
parser = argparse.ArgumentParser()
parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
args = parser.parse_args()
# load model config
config_path = args.config_path
c = load_config(config_path)
num_speakers = 0
# init torch model
model = setup_model(c)
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
# init tf model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model_tf = Tacotron2(
num_chars=num_chars,
num_speakers=num_speakers,
r=model.decoder.r,
out_channels=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
)
# set initial layer mapping - these are not captured by the below heuristic approach
# TODO: set layer names so that we can remove these manual matching
common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
var_map = [
("embedding/embeddings:0", "embedding.weight"),
("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"),
("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"),
("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"),
("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"),
("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")),
(
"encoder/lstm/backward_lstm/lstm_cell_2/bias:0",
("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"),
),
("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"),
("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"),
("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"),
]
# %%
# get tf_model graph
model_tf.build_inference()
# get tf variables
tf_vars = model_tf.weights
# match variable names with fuzzy logic
torch_var_names = list(state_dict.keys())
tf_var_names = [we.name for we in model_tf.weights]
for tf_name in tf_var_names:
    # skip re-mapped layer names
    if tf_name in [name[0] for name in var_map]:
        continue
    tf_name_edited = convert_tf_name(tf_name)
    ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
    max_idx = np.argmax(ratios)
    matching_name = torch_var_names[max_idx]
    del torch_var_names[max_idx]
    var_map.append((tf_name, matching_name))
pprint(var_map)
pprint(torch_var_names)
# pass weights
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
# Compare TF and TORCH models
# %%
# check embedding outputs
model.eval()
input_ids = torch.randint(0, 24, (1, 128)).long()
o_t = model.embedding(input_ids)
o_tf = model_tf.embedding(input_ids.detach().numpy())
assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum()
# compare encoder outputs
oo_en = model.encoder.inference(o_t.transpose(1, 2))
ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
assert compare_torch_tf(oo_en, ooo_en) < 1e-5
# pylint: disable=redefined-builtin
# compare decoder.attention_rnn
inp = torch.rand([1, 768])
inp_tf = inp.numpy()
model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access
output, cell_state = model.decoder.attention_rnn(inp)
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False)
assert compare_torch_tf(output, output_tf).mean() < 1e-5
query = output
inputs = torch.rand([1, 128, 512])
query_tf = query.detach().numpy()
inputs_tf = inputs.numpy()
# compare decoder.attention
model.decoder.attention.init_states(inputs)
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs)
context = model.decoder.attention(query, inputs, processes_inputs, None)
attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states)
context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False)
assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
assert compare_torch_tf(context, context_tf) < 1e-5
# compare decoder.decoder_rnn
input = torch.rand([1, 1536])
input_tf = input.numpy()
model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access
output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False)
assert abs(input - input_tf).mean() < 1e-5
assert compare_torch_tf(output, output_tf).mean() < 1e-5
# compare decoder.linear_projection
input = torch.rand([1, 1536])
input_tf = input.numpy()
output = model.decoder.linear_projection(input)
output_tf = model_tf.decoder.linear_projection(input_tf, training=False)
assert compare_torch_tf(output, output_tf) < 1e-5
# compare decoder outputs
model.decoder.max_decoder_steps = 100
model_tf.decoder.set_max_decoder_steps(100)
output, align, stop = model.decoder.inference(oo_en)
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4
# compare the whole model output
outputs_torch = model.inference(input_ids)
outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean())
assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5
assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
# %%
# save tf model
save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path)
print(" > Model conversion is successfully completed :).")

@@ -7,15 +7,14 @@ import subprocess
 import time

 import torch
+from trainer import TrainerArgs
-from TTS.trainer import TrainingArgs


 def main():
     """
     Call train.py as a new process and pass command arguments
     """
-    parser = TrainingArgs().init_argparse(arg_prefix="")
+    parser = TrainerArgs().init_argparse(arg_prefix="")
     parser.add_argument("--script", type=str, help="Target training script to distibute.")
     args, unargs = parser.parse_known_args()
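
`distribute.py` now builds its parser from the external `trainer` package's `TrainerArgs` instead of the removed `TTS.trainer.TrainingArgs`; usage is unchanged. A hedged sketch, where `--script` comes from the hunk above and the `TTS/bin/` paths are assumptions based on the README layout:

```bash
# launch multi-GPU training; remaining arguments are forwarded to the target script
python TTS/bin/distribute.py --script TTS/bin/train_tts.py --config_path config.json
```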

@@ -13,6 +13,7 @@ from TTS.config import load_config
 from TTS.tts.datasets import TTSDataset, load_tts_samples
 from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters

@@ -20,21 +21,20 @@ use_cuda = torch.cuda.is_available()


 def setup_loader(ap, r, verbose=False):
+    tokenizer, _ = TTSTokenizer.init_from_config(c)
     dataset = TTSDataset(
-        r,
-        c.text_cleaner,
+        outputs_per_step=r,
         compute_linear_spec=False,
-        meta_data=meta_data,
+        samples=meta_data,
+        tokenizer=tokenizer,
         ap=ap,
-        characters=c.characters if "characters" in c.keys() else None,
-        add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
         batch_group_size=0,
-        min_seq_len=c.min_seq_len,
-        max_seq_len=c.max_seq_len,
+        min_text_len=c.min_text_len,
+        max_text_len=c.max_text_len,
+        min_audio_len=c.min_audio_len,
+        max_audio_len=c.max_audio_len,
         phoneme_cache_path=c.phoneme_cache_path,
-        use_phonemes=c.use_phonemes,
-        phoneme_language=c.phoneme_language,
-        enable_eos_bos=c.enable_eos_bos_chars,
+        precompute_num_workers=0,
         use_noise_augment=False,
         verbose=verbose,
         speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,

@@ -44,7 +44,7 @@ def setup_loader(ap, r, verbose=False):
     if c.use_phonemes and c.compute_input_seq_cache:
         # precompute phonemes to have a better estimate of sequence lengths.
         dataset.compute_input_seq(c.num_loader_workers)
-    dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
+    dataset.preprocess_samples()

     loader = DataLoader(
         dataset,

@@ -75,8 +75,8 @@ def set_filename(wav_path, out_path):
 def format_data(data):
     # setup input data
-    text_input = data["text"]
-    text_lengths = data["text_lengths"]
+    text_input = data["token_id"]
+    text_lengths = data["token_id_lengths"]
     mel_input = data["mel"]
     mel_lengths = data["mel_lengths"]
     item_idx = data["item_idxs"]

@@ -138,7 +138,7 @@ def inference(
             aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
         )
         model_output = outputs["model_outputs"]
-        model_output = model_output.transpose(1, 2).detach().cpu().numpy()
+        model_output = model_output.detach().cpu().numpy()
     elif "tacotron" in model_name:
         aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}

@@ -229,7 +229,9 @@ def main(args):  # pylint: disable=redefined-outer-name
     ap = AudioProcessor(**c.audio)

     # load data instances
-    meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval)
+    meta_data_train, meta_data_eval = load_tts_samples(
+        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
+    )

     # use eval and training partitions
     meta_data = meta_data_train + meta_data_eval

@@ -23,7 +23,10 @@ def main():
     c = load_config(args.config_path)

     # load all datasets
-    train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
+    train_items, eval_items = load_tts_samples(
+        c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
+    )
     items = train_items + eval_items

     texts = "".join(item[0] for item in items)

@@ -7,14 +7,15 @@ from tqdm.contrib.concurrent import process_map
 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.utils.text import text2phone
+from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
+
+phonemizer = Gruut(language="en-us")


 def compute_phonemes(item):
     try:
         text = item[0]
-        language = item[-1]
-        ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
+        ph = phonemizer.phonemize(text).split("|")
     except:
         return []
     return list(set(ph))

@@ -39,10 +40,17 @@ def main():
     c = load_config(args.config_path)

     # load all datasets
-    train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
+    train_items, eval_items = load_tts_samples(
+        c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
+    )
     items = train_items + eval_items
     print("Num items:", len(items))

+    is_lang_def = all(item["language"] for item in items)
+
+    if not c.phoneme_language or not is_lang_def:
+        raise ValueError("Phoneme language must be defined in config.")
+
     phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
     phones = []
     for ph in phonemes:

@@ -26,6 +26,7 @@ if __name__ == "__main__":
             --input_dir /root/LJSpeech-1.1/
             --output_sr 22050
             --output_dir /root/resampled_LJSpeech-1.1/
+            --file_ext wav
             --n_jobs 24
         """,
         formatter_class=RawTextHelpFormatter,

@@ -55,6 +56,14 @@ if __name__ == "__main__":
         help="Path of the destination folder. If not defined, the operation is done in place",
     )

+    parser.add_argument(
+        "--file_ext",
+        type=str,
+        default="wav",
+        required=False,
+        help="Extension of the audio files to resample",
+    )
+
     parser.add_argument(
         "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores"
     )

@@ -67,7 +76,7 @@ if __name__ == "__main__":
         args.input_dir = args.output_dir

     print("Resampling the audio files...")
-    audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True)
+    audio_files = glob.glob(os.path.join(args.input_dir, f"**/*.{args.file_ext}"), recursive=True)
     print(f"Found {len(audio_files)} files...")
     audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr]))
     with Pool(processes=args.n_jobs) as p:
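
With the new `--file_ext` flag the resampler is no longer hard-wired to `*.wav`. A usage sketch based on the help text shown above; the `TTS/bin/resample.py` path and the `flac` extension are illustrative assumptions:

```bash
python TTS/bin/resample.py \
    --input_dir /root/LJSpeech-1.1/ \
    --output_sr 22050 \
    --output_dir /root/resampled_LJSpeech-1.1/ \
    --file_ext flac \
    --n_jobs 24
```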

@@ -8,6 +8,7 @@ import traceback

 import torch
 from torch.utils.data import DataLoader
+from trainer.torch import NoamLR

 from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss

@@ -19,7 +20,7 @@ from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
 from TTS.utils.io import load_fsspec
 from TTS.utils.radam import RAdam
-from TTS.utils.training import NoamLR, check_update
+from TTS.utils.training import check_update

 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True

@@ -1,19 +1,22 @@
 import os
-import torch
+from dataclasses import dataclass, field

-from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs
+
+from TTS.config import load_config, register_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models import setup_model
-from TTS.tts.utils.languages import LanguageManager
-from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.audio import AudioProcessor
+
+
+@dataclass
+class TrainTTSArgs(TrainerArgs):
+    config_path: str = field(default=None, metadata={"help": "Path to the config file."})


 def main():
     """Run `tts` model training directly by a `config.json` file."""
     # init trainer args
-    train_args = TrainingArgs()
+    train_args = TrainTTSArgs()
     parser = train_args.init_argparse(arg_prefix="")

     # override trainer args from comman-line args

@@ -41,45 +44,15 @@ def main():
     config = register_config(config_base.model)()

     # load training samples
-    train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True)
-
-    # setup audio processor
-    ap = AudioProcessor(**config.audio)
-
-    # init speaker manager
-    if check_config_and_model_args(config, "use_speaker_embedding", True):
-        speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
-        if hasattr(config, "model_args"):
-            config.model_args.num_speakers = speaker_manager.num_speakers
-        else:
-            config.num_speakers = speaker_manager.num_speakers
-    elif check_config_and_model_args(config, "use_d_vector_file", True):
-        if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
-            speaker_manager = SpeakerManager(
-                d_vectors_file_path=config.model_args.d_vector_file,
-                encoder_model_path=config.model_args.speaker_encoder_model_path,
-                encoder_config_path=config.model_args.speaker_encoder_config_path,
-                use_cuda=torch.cuda.is_available(),
-            )
-        else:
-            speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
-        config.num_speakers = speaker_manager.num_speakers
-        if hasattr(config, "model_args"):
-            config.model_args.num_speakers = speaker_manager.num_speakers
-    else:
-        speaker_manager = None
-
-    if check_config_and_model_args(config, "use_language_embedding", True):
-        language_manager = LanguageManager(config=config)
-        if hasattr(config, "model_args"):
-            config.model_args.num_languages = language_manager.num_languages
-        else:
-            config.num_languages = language_manager.num_languages
-    else:
-        language_manager = None
+    train_samples, eval_samples = load_tts_samples(
+        config.datasets,
+        eval_split=True,
+        eval_split_max_size=config.eval_split_max_size,
+        eval_split_size=config.eval_split_size,
+    )

     # init the model from config
-    model = setup_model(config, speaker_manager, language_manager)
+    model = setup_model(config, train_samples + eval_samples)

     # init the trainer and 🚀
     trainer = Trainer(

@@ -89,7 +62,6 @@ def main():
         model=model,
         train_samples=train_samples,
         eval_samples=eval_samples,
-        training_assets={"audio_processor": ap},
         parse_command_line_args=False,
     )
     trainer.fit()
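
Speaker and language managers are no longer wired up by the script; `setup_model(config, train_samples + eval_samples)` receives the samples and the model builds what it needs. Launching reduces to pointing the new `TrainTTSArgs` entry point at a config; a hedged sketch, with the script path assumed and any `TrainerArgs` field still overridable on the command line as the script's override step shows:

```bash
# hypothetical path; config.json must name the model so register_config() can resolve it
python TTS/bin/train_tts.py --config_path config.json
```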

@@ -1,16 +1,23 @@
 import os
+from dataclasses import dataclass, field
+
+from trainer import Trainer, TrainerArgs

 from TTS.config import load_config, register_config
-from TTS.trainer import Trainer, TrainingArgs
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model


+@dataclass
+class TrainVocoderArgs(TrainerArgs):
+    config_path: str = field(default=None, metadata={"help": "Path to the config file."})
+
+
 def main():
     """Run `tts` model training directly by a `config.json` file."""
     # init trainer args
-    train_args = TrainingArgs()
+    train_args = TrainVocoderArgs()
     parser = train_args.init_argparse(arg_prefix="")

     # override trainer args from comman-line args
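
The vocoder entry point gets the same treatment through `TrainVocoderArgs`. A hedged sketch mirroring the TTS launcher above, with the script path assumed:

```bash
python TTS/bin/train_vocoder.py --config_path vocoder_config.json
```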

@@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass
 from typing import List

 from coqpit import Coqpit, check_argument
+from trainer import TrainerConfig


 @dataclass

@@ -57,6 +58,12 @@ class BaseAudioConfig(Coqpit):
         do_amp_to_db_mel (bool, optional):
             enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.

+        pitch_fmax (float, optional):
+            Maximum frequency of the F0 frames. Defaults to ```640```.
+
+        pitch_fmin (float, optional):
+            Minimum frequency of the F0 frames. Defaults to ```0```.
+
         trim_db (int):
             Silence threshold used for silence trimming. Defaults to 45.

@@ -135,6 +142,9 @@
     spec_gain: int = 20
     do_amp_to_db_linear: bool = True
     do_amp_to_db_mel: bool = True
+    # f0 params
+    pitch_fmax: float = 640.0
+    pitch_fmin: float = 0.0
     # normalization params
     signal_norm: bool = True
     min_level_db: int = -100
@@ -228,130 +238,24 @@ class BaseDatasetConfig(Coqpit):

 @dataclass
-class BaseTrainingConfig(Coqpit):
-    """Base config to define the basic training parameters that are shared
-    among all the models.
+class BaseTrainingConfig(TrainerConfig):
+    """Base config to define the basic 🐸TTS training parameters that are shared
+    among all the models. It is based on ```Trainer.TrainingConfig```.

     Args:
         model (str):
             Name of the model that is used in the training.

-        run_name (str):
-            Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.
-        run_description (str):
-            Short description of the experiment.
-        epochs (int):
-            Number training epochs. Defaults to 10000.
-        batch_size (int):
-            Training batch size.
-        eval_batch_size (int):
-            Validation batch size.
-        mixed_precision (bool):
-            Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
-            it may also cause numerical unstability in some cases.
-        scheduler_after_epoch (bool):
-            If true, run the scheduler step after each epoch else run it after each model step.
-        run_eval (bool):
-            Enable / Disable evaluation (validation) run. Defaults to True.
-        test_delay_epochs (int):
-            Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
-            results, hence waiting for a couple of epochs might save some time.
-        print_eval (bool):
-            Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at
-            the end of the evaluation. Default to ```False```.
-        print_step (int):
-            Number of steps required to print the next training log.
-        log_dashboard (str): "tensorboard" or "wandb"
-            Set the experiment tracking tool
-        plot_step (int):
-            Number of steps required to log training on Tensorboard.
-        model_param_stats (bool):
-            Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
-            Defaults to ```False```.
-        project_name (str):
-            Name of the project. Defaults to config.model
-        wandb_entity (str):
-            Name of W&B entity/team. Enables collaboration across a team or org.
-        log_model_step (int):
-            Number of steps required to log a checkpoint as W&B artifact
-        save_step (int):ipt
-            Number of steps required to save the next checkpoint.
-        checkpoint (bool):
-            Enable / Disable checkpointing.
-        keep_all_best (bool):
-            Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
-            to ```False```.
-        keep_after (int):
-            Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
-            to 10000.
         num_loader_workers (int):
             Number of workers for training time dataloader.
         num_eval_loader_workers (int):
             Number of workers for evaluation time dataloader.
-        output_path (str):
-            Path for training output folder, either a local file path or other
-            URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
-            S3 (s3://) paths. The nonexist part of the given path is created
-            automatically. All training artefacts are saved there.
     """

     model: str = None
-    run_name: str = "coqui_tts"
-    run_description: str = ""
-    # training params
-    epochs: int = 10000
-    batch_size: int = None
-    eval_batch_size: int = None
-    mixed_precision: bool = False
-    scheduler_after_epoch: bool = False
-    # eval params
-    run_eval: bool = True
-    test_delay_epochs: int = 0
-    print_eval: bool = False
-    # logging
-    dashboard_logger: str = "tensorboard"
-    print_step: int = 25
-    plot_step: int = 100
-    model_param_stats: bool = False
-    project_name: str = None
-    log_model_step: int = None
-    wandb_entity: str = None
-    # checkpointing
-    save_step: int = 10000
-    checkpoint: bool = True
-    keep_all_best: bool = False
-    keep_after: int = 10000
     # dataloading
     num_loader_workers: int = 0
     num_eval_loader_workers: int = 0
     use_noise_augment: bool = False
     use_language_weighted_sampler: bool = False
-    # paths
-    output_path: str = None
-    # distributed
-    distributed_backend: str = "nccl"
-    distributed_url: str = "tcp://localhost:54321"

@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

-import numpy as np
 import torch
 from coqpit import Coqpit
 from torch import nn

@@ -9,28 +8,21 @@ from torch import nn
 # pylint: skip-file


-class BaseModel(nn.Module, ABC):
-    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
+class BaseTrainerModel(ABC, nn.Module):
+    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""

-    Notes on input/output tensor shapes:
-        Any input or output tensor of the model must be shaped as
+    @staticmethod
+    @abstractmethod
+    def init_from_config(config: Coqpit):
+        """Init the model from given config.

-        - 3D tensors `batch x time x channels`
-        - 2D tensors `batch x channels`
-        - 1D tensors `batch x 1`
-    """
-
-    def __init__(self, config: Coqpit):
-        super().__init__()
-        self._set_model_args(config)
-
-    def _set_model_args(self, config: Coqpit):
-        """Set model arguments from the config. Override this."""
-        pass
+        Override this depending on your model.
+        """
+        ...

     @abstractmethod
     def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
-        """Forward pass for the model mainly used in training.
+        """Forward ... for the model mainly used in training.

         You can be flexible here and use different number of arguments and argument names since it is intended to be
         used by `train_step()` without exposing it out of the model.
@@ -48,7 +40,7 @@ class BaseModel(nn.Module, ABC):
     @abstractmethod
     def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
-        """Forward pass for inference.
+        """Forward ... for inference.

         We don't use `*kwargs` since it is problematic with the TorchScript API.

@@ -63,9 +55,25 @@ class BaseModel(nn.Module, ABC):
             ...
         return outputs_dict

+    def format_batch(self, batch: Dict) -> Dict:
+        """Format batch returned by the data loader before sending it to the model.
+
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature ectraction, etc.
+        """
+        return batch
+
+    def format_batch_on_device(self, batch: Dict) -> Dict:
+        """Format batch on device before sending it to the model.
+
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature ectraction, etc.
+        """
+        return batch
+
     @abstractmethod
     def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single training step. Run the model forward pass and compute losses.
+        """Perform a single training step. Run the model forward ... and compute losses.

         Args:
             batch (Dict): Input tensors.

@@ -93,11 +101,11 @@ class BaseModel(nn.Module, ABC):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        pass
+        ...

     @abstractmethod
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
+        """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can
         call `train_step()` with no changes.

         Args:
@ -114,36 +122,49 @@ class BaseModel(nn.Module, ABC):
def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
"""The same as `train_log()`""" """The same as `train_log()`"""
pass ...
@abstractmethod @abstractmethod
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
"""Load a checkpoint and get ready for training or inference. """Load a checkpoint and get ready for training or inference.
Args: Args:
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
checkpoint_path (str): Path to the model checkpoint file. checkpoint_path (str): Path to the model checkpoint file.
eval (bool, optional): If true, init model for inference else for training. Defaults to False. eval (bool, optional): If true, init model for inference else for training. Defaults to False.
strict (bool, optional): Match all checkpoint keys to the model's keys. Defaults to True.
""" """
... ...
def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: @staticmethod
"""Setup an return optimizer or optimizers.""" @abstractmethod
pass def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel":
"""Init the model from given config.
def get_lr(self) -> Union[float, List[float]]: Override this depending on your model.
"""Return learning rate(s).
Returns:
Union[float, List[float]]: Model's initial learning rates.
""" """
pass ...
def get_scheduler(self, optimizer: torch.optim.Optimizer): @abstractmethod
pass def get_data_loader(
self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
):
...
def get_criterion(self): # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
pass # """Set up and return an optimizer or optimizers."""
# ...
def format_batch(self): # def get_lr(self) -> Union[float, List[float]]:
pass # """Return learning rate(s).
# Returns:
# Union[float, List[float]]: Model's initial learning rates.
# """
# ...
# def get_scheduler(self, optimizer: torch.optim.Optimizer):
# ...
# def get_criterion(self):
# ...
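For illustration, a minimal sketch of a model that satisfies this abstract trainer-model interface. The local TrainerModelABC stand-in, MyModel, and its layer sizes are hypothetical and only mirror the abstract methods shown in the hunk above; the real base class and its full method set ship with the package.

from abc import ABC, abstractmethod
from typing import Dict, Tuple

import torch
from torch import nn


class TrainerModelABC(ABC, nn.Module):
    # Hypothetical stand-in mirroring the abstract interface above.
    @abstractmethod
    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
        ...

    @abstractmethod
    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        ...


class MyModel(TrainerModelABC):
    def __init__(self, in_dim: int = 80, out_dim: int = 80):
        super().__init__()
        self.net = nn.Linear(in_dim, out_dim)

    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
        return {"model_outputs": self.net(input)}

    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        outputs = self.forward(batch["mel_input"])
        loss = criterion(outputs["model_outputs"], batch["mel_target"])
        return outputs, {"loss": loss}


model = MyModel()
batch = {"mel_input": torch.randn(4, 80), "mel_target": torch.randn(4, 80)}
outputs, losses = model.train_step(batch, nn.MSELoss())
print(losses["loss"].item())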

View File

@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path:
if args.vocoder_name is not None and not args.vocoder_path: if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE3: set custome model paths # CASE3: set custom model paths
if args.model_path is not None: if args.model_path is not None:
model_path = args.model_path model_path = args.model_path
config_path = args.config_path config_path = args.config_path
@ -170,9 +170,9 @@ def tts():
text = request.args.get("text") text = request.args.get("text")
speaker_idx = request.args.get("speaker_id", "") speaker_idx = request.args.get("speaker_id", "")
style_wav = request.args.get("style_wav", "") style_wav = request.args.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav) style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text)) print(" > Model input: {}".format(text))
print(" > Speaker Idx: {}".format(speaker_idx))
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav) wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
out = io.BytesIO() out = io.BytesIO()
synthesizer.save_wav(wavs, out) synthesizer.save_wav(wavs, out)
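For context, a hedged client-side example of exercising this route; the /api/tts path, the 5002 port, and the speaker name are assumed defaults of the demo server and may differ in a given deployment.

import requests

# text, speaker_id and style_wav map to the request.args lookups above.
params = {"text": "Hello from the demo server.", "speaker_id": "p225"}
response = requests.get("http://localhost:5002/api/tts", params=params)
response.raise_for_status()

with open("tts_output.wav", "wb") as f:
    f.write(response.content)  # the endpoint returns a WAV payload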

View File

@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset):
mel = self.ap.melspectrogram(wav).astype("float32") mel = self.ap.melspectrogram(wav).astype("float32")
# sample seq_len # sample seq_len
assert text.size > 0, self.items[idx][1] assert text.size > 0, self.items[idx]["audio_file"]
assert wav.size > 0, self.items[idx][1] assert wav.size > 0, self.items[idx]["audio_file"]
sample = { sample = {
"mel": mel, "mel": mel,
"item_idx": self.items[idx][1], "item_idx": self.items[idx]["audio_file"],
"speaker_name": speaker_name, "speaker_name": speaker_name,
} }
return sample return sample
@ -91,8 +91,8 @@ class SpeakerEncoderDataset(Dataset):
def __parse_items(self): def __parse_items(self):
self.speaker_to_utters = {} self.speaker_to_utters = {}
for i in self.items: for i in self.items:
path_ = i[1] path_ = i["audio_file"]
speaker_ = i[2] speaker_ = i["speaker_name"]
if speaker_ in self.speaker_to_utters.keys(): if speaker_ in self.speaker_to_utters.keys():
self.speaker_to_utters[speaker_].append(path_) self.speaker_to_utters[speaker_].append(path_)
else: else:
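A small sketch of the dict-based samples this parser now consumes and how utterances end up grouped per speaker; the file names and speaker labels are made up for illustration.

from collections import defaultdict

# Each sample is now a dict instead of a positional list.
items = [
    {"audio_file": "wavs/spk1_001.wav", "speaker_name": "spk1"},
    {"audio_file": "wavs/spk1_002.wav", "speaker_name": "spk1"},
    {"audio_file": "wavs/spk2_001.wav", "speaker_name": "spk2"},
]

# Group utterance paths by speaker, mirroring __parse_items above.
speaker_to_utters = defaultdict(list)
for item in items:
    speaker_to_utters[item["speaker_name"]].append(item["audio_file"])

print(dict(speaker_to_utters))
# {'spk1': ['wavs/spk1_001.wav', 'wavs/spk1_002.wav'], 'spk2': ['wavs/spk2_001.wav']}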

View File

@ -229,7 +229,7 @@ class ResNetSpeakerEncoder(nn.Module):
x = torch.sum(x * w, dim=2) x = torch.sum(x * w, dim=2)
elif self.encoder_type == "ASP": elif self.encoder_type == "ASP":
mu = torch.sum(x * w, dim=2) mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5)) sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1) x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1) x = x.view(x.size()[0], -1)
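The ASP branch above concatenates the attention-weighted mean and standard deviation over time. A standalone sketch of that computation on random tensors; the batch, channel and frame sizes (and the per-channel attention weights) are illustrative only.

import torch

batch, channels, frames = 2, 128, 50
x = torch.randn(batch, channels, frames)  # frame-level features
w = torch.softmax(torch.randn(batch, channels, frames), dim=2)  # attention weights over time

mu = torch.sum(x * w, dim=2)  # weighted mean
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))  # weighted std
pooled = torch.cat((mu, sg), 1)  # (batch, 2 * channels)
print(pooled.shape)  # torch.Size([2, 256])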

View File

@ -113,7 +113,7 @@ class AugmentWAV(object):
def additive_noise(self, noise_type, audio): def additive_noise(self, noise_type, audio):
clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4) clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
noise_list = random.sample( noise_list = random.sample(
self.noise_list[noise_type], self.noise_list[noise_type],
@ -135,7 +135,7 @@ class AugmentWAV(object):
self.additive_noise_config[noise_type]["min_snr_in_db"], self.additive_noise_config[noise_type]["min_snr_in_db"],
self.additive_noise_config[noise_type]["max_num_noises"], self.additive_noise_config[noise_type]["max_num_noises"],
) )
noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4) noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
if noises_wav is None: if noises_wav is None:
@ -154,7 +154,7 @@ class AugmentWAV(object):
rir_file = random.choice(self.rir_files) rir_file = random.choice(self.rir_files)
rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
rir = rir / np.sqrt(np.sum(rir ** 2)) rir = rir / np.sqrt(np.sum(rir**2))
return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
def apply_one(self, audio): def apply_one(self, audio):
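The additive-noise path rescales a noise clip so the mix lands at a target SNR relative to the clean signal. A minimal NumPy sketch of that scaling; the signals and the target SNR are arbitrary stand-ins.

import numpy as np

audio = 0.1 * np.random.randn(16000).astype(np.float32)       # clean signal
noiseaudio = 0.1 * np.random.randn(16000).astype(np.float32)  # noise clip
noise_snr = 10.0                                              # target SNR in dB

clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio

augmented = audio + noise_wav
print(augmented.shape, clean_db, noise_db)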

View File

@ -1,19 +1,24 @@
import os import os
from dataclasses import dataclass, field
from coqpit import Coqpit from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config from TTS.config import load_config, register_config
from TTS.trainer import TrainingArgs from TTS.tts.utils.text.characters import parse_symbols
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files from TTS.utils.io import copy_model_files
from TTS.utils.logging import init_dashboard_logger
from TTS.utils.logging.console_logger import ConsoleLogger
from TTS.utils.trainer_utils import get_last_checkpoint @dataclass
class TrainArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
def getarguments(): def getarguments():
train_config = TrainingArgs() train_config = TrainArgs()
parser = train_config.init_argparse(arg_prefix="") parser = train_config.init_argparse(arg_prefix="")
return parser return parser
@ -75,13 +80,13 @@ def process_args(args, config=None):
used_characters = parse_symbols() used_characters = parse_symbols()
new_fields["characters"] = used_characters new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields) copy_model_files(config, experiment_path, new_fields)
dashboard_logger = init_dashboard_logger(config) dashboard_logger = logger_factory(config, experiment_path)
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger return config, experiment_path, audio_path, c_logger, dashboard_logger
def init_arguments(): def init_arguments():
train_config = TrainingArgs() train_config = TrainArgs()
parser = train_config.init_argparse(arg_prefix="") parser = train_config.init_argparse(arg_prefix="")
return parser return parser
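The entry point now extends TrainerArgs with a config_path field and builds its CLI parser from the dataclass fields. A dependency-free sketch of the same pattern with plain argparse and dataclasses; the field set is trimmed and the restore_path field is a hypothetical extra for illustration.

import argparse
from dataclasses import dataclass, fields


@dataclass
class TrainArgsSketch:
    config_path: str = None
    restore_path: str = None  # hypothetical extra field


def init_argparse(args_cls) -> argparse.ArgumentParser:
    # Build one --flag per dataclass field, mirroring init_argparse(arg_prefix="").
    parser = argparse.ArgumentParser()
    for f in fields(args_cls):
        parser.add_argument(f"--{f.name}", type=str, default=f.default)
    return parser


parser = init_argparse(TrainArgsSketch)
print(parser.parse_args(["--config_path", "config.json"]))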

File diff suppressed because it is too large

View File

@ -89,11 +89,11 @@ class FastPitchConfig(BaseTTSConfig):
pitch_loss_alpha (float): pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
binary_loss_alpha (float): binary_align_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int): binary_loss_warmup_epochs (int):
Start binary alignment loss after this many steps. Defaults to 20000. Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int): min_seq_len (int):
Minimum input sequence length to be used at training. Minimum input sequence length to be used at training.
@ -129,12 +129,12 @@ class FastPitchConfig(BaseTTSConfig):
duration_loss_type: str = "mse" duration_loss_type: str = "mse"
use_ssim_loss: bool = True use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0 ssim_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0 spec_loss_alpha: float = 1.0
pitch_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0 pitch_loss_alpha: float = 0.1
binary_align_loss_start_step: int = 20000 dur_loss_alpha: float = 0.1
binary_align_loss_alpha: float = 0.1
binary_loss_warmup_epochs: int = 150
# overrides # overrides
min_seq_len: int = 13 min_seq_len: int = 13
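binary_loss_warmup_epochs replaces the old step-based switch with a gradual ramp of the binary alignment loss weight. One plausible schedule is a linear ramp over the warmup epochs; the exact curve lives in the model code and is not shown in this hunk, so the function below is an illustration only.

def binary_loss_weight(epoch: int, warmup_epochs: int = 150, alpha: float = 0.1) -> float:
    # Linearly ramp the binary alignment loss weight from 0 to `alpha` (illustrative).
    if warmup_epochs <= 0:
        return alpha
    return alpha * min(1.0, epoch / warmup_epochs)

print([round(binary_loss_weight(e), 3) for e in (0, 75, 150, 300)])  # [0.0, 0.05, 0.1, 0.1]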

View File

@ -93,8 +93,8 @@ class FastSpeechConfig(BaseTTSConfig):
binary_loss_alpha (float): binary_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int): binary_loss_warmup_epochs (int):
Start binary alignment loss after this many steps. Defaults to 20000. Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int): min_seq_len (int):
Minimum input sequence length to be used at training. Minimum input sequence length to be used at training.
@ -135,7 +135,7 @@ class FastSpeechConfig(BaseTTSConfig):
pitch_loss_alpha: float = 0.0 pitch_loss_alpha: float = 0.0
aligner_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000 binary_loss_warmup_epochs: int = 150
# overrides # overrides
min_seq_len: int = 13 min_seq_len: int = 13

View File

@ -153,6 +153,7 @@ class GlowTTSConfig(BaseTTSConfig):
# multi-speaker settings # multi-speaker settings
use_speaker_embedding: bool = False use_speaker_embedding: bool = False
speakers_file: str = None
use_d_vector_file: bool = False use_d_vector_file: bool = False
d_vector_file: str = False d_vector_file: str = False

View File

@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import List from typing import Dict, List
from coqpit import Coqpit, check_argument from coqpit import Coqpit, check_argument
@ -50,9 +50,16 @@ class GSTConfig(Coqpit):
@dataclass @dataclass
class CharactersConfig(Coqpit): class CharactersConfig(Coqpit):
"""Defines character or phoneme set used by the model """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses.
Args: Args:
characters_class (str):
Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on
the configuration. Defaults to None.
vocab_dict (dict):
Defines the vocabulary dictionary used to encode the characters. Defaults to None.
pad (str): pad (str):
characters in place of empty padding. Defaults to None. characters in place of empty padding. Defaults to None.
@ -62,6 +69,9 @@ class CharactersConfig(Coqpit):
bos (str): bos (str):
characters showing the beginning of a sentence. Defaults to None. characters showing the beginning of a sentence. Defaults to None.
blank (str):
Optional character used between characters by some models for better prosody. Defaults to `_blank`.
characters (str): characters (str):
character set used by the model. Characters not in this list are ignored when converting input text to character set used by the model. Characters not in this list are ignored when converting input text to
a list of sequence IDs. Defaults to None. a list of sequence IDs. Defaults to None.
@ -70,32 +80,32 @@ class CharactersConfig(Coqpit):
characters considered as punctuation when parsing the input sentence. Defaults to None. characters considered as punctuation when parsing the input sentence. Defaults to None.
phonemes (str): phonemes (str):
characters considered as parsing phonemes. Defaults to None. characters considered as parsing phonemes. This is only for backwards compat. Use `characters` for new
models. Defaults to None.
unique (bool): is_unique (bool):
remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old
models trained with character lists with duplicates. models trained with character lists with duplicates. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
""" """
characters_class: str = None
# using BaseVocabulary
vocab_dict: Dict = None
# using on BaseCharacters
pad: str = None pad: str = None
eos: str = None eos: str = None
bos: str = None bos: str = None
blank: str = None
characters: str = None characters: str = None
punctuations: str = None punctuations: str = None
phonemes: str = None phonemes: str = None
unique: bool = True # for backwards compatibility of models trained with char sets with duplicates is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates
is_sorted: bool = True
def check_values(
self,
):
"""Check config fields"""
c = asdict(self)
check_argument("pad", c, prerequest="characters", restricted=True)
check_argument("eos", c, prerequest="characters", restricted=True)
check_argument("bos", c, prerequest="characters", restricted=True)
check_argument("characters", c, prerequest="characters", restricted=True)
check_argument("phonemes", c, restricted=True)
check_argument("punctuations", c, prerequest="characters", restricted=True)
@dataclass @dataclass
@ -110,8 +120,13 @@ class BaseTTSConfig(BaseTrainingConfig):
use_phonemes (bool): use_phonemes (bool):
enable / disable phoneme use. enable / disable phoneme use.
use_espeak_phonemes (bool): phonemizer (str):
enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`). Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`.
Defaults to None.
phoneme_language (str):
Language code for the phonemizer. You can check the list of supported languages by running
`python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None.
compute_input_seq_cache (bool): compute_input_seq_cache (bool):
enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of
@ -144,11 +159,19 @@ class BaseTTSConfig(BaseTrainingConfig):
sort_by_audio_len (bool): sort_by_audio_len (bool):
If true, dataloader sorts the data by audio length else sorts by the input text length. Defaults to `False`. If true, dataloader sorts the data by audio length else sorts by the input text length. Defaults to `False`.
min_seq_len (int): min_text_len (int):
Minimum sequence length to be used at training. Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.
max_seq_len (int): max_text_len (int):
Maximum sequence length to be used at training. Larger values result in more VRAM usage. Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").
min_audio_len (int):
Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.
max_audio_len (int):
Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
OOM error in training. Defaults to float("inf").
compute_f0 (int): compute_f0 (int):
(Not in use yet). (Not in use yet).
@ -156,9 +179,16 @@ class BaseTTSConfig(BaseTrainingConfig):
compute_linear_spec (bool): compute_linear_spec (bool):
If True data loader computes and returns linear spectrograms alongside the other data. If True data loader computes and returns linear spectrograms alongside the other data.
precompute_num_workers (int):
Number of workers to precompute features. Defaults to 0.
use_noise_augment (bool): use_noise_augment (bool):
Augment the input audio with random noise. Augment the input audio with random noise.
start_by_longest (bool):
If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
Defaults to False.
add_blank (bool): add_blank (bool):
Add blank characters between each other two characters. It improves performance for some models at expense Add blank characters between each other two characters. It improves performance for some models at expense
of slower run-time due to the longer input sequence. of slower run-time due to the longer input sequence.
@ -183,12 +213,19 @@ class BaseTTSConfig(BaseTrainingConfig):
test_sentences (List[str]): test_sentences (List[str]):
List of sentences to be used at testing. Defaults to '[]' List of sentences to be used at testing. Defaults to '[]'
eval_split_max_size (int):
Maximum number of samples to be used for evaluation in a proportional split. Defaults to None (disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
""" """
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
# phoneme settings # phoneme settings
use_phonemes: bool = False use_phonemes: bool = False
use_espeak_phonemes: bool = True phonemizer: str = None
phoneme_language: str = None phoneme_language: str = None
compute_input_seq_cache: bool = False compute_input_seq_cache: bool = False
text_cleaner: str = None text_cleaner: str = None
@ -197,17 +234,21 @@ class BaseTTSConfig(BaseTrainingConfig):
phoneme_cache_path: str = None phoneme_cache_path: str = None
# vocabulary parameters # vocabulary parameters
characters: CharactersConfig = None characters: CharactersConfig = None
add_blank: bool = False
# training params # training params
batch_group_size: int = 0 batch_group_size: int = 0
loss_masking: bool = None loss_masking: bool = None
# dataloading # dataloading
sort_by_audio_len: bool = False sort_by_audio_len: bool = False
min_seq_len: int = 1 min_audio_len: int = 1
max_seq_len: int = float("inf") max_audio_len: int = float("inf")
min_text_len: int = 1
max_text_len: int = float("inf")
compute_f0: bool = False compute_f0: bool = False
compute_linear_spec: bool = False compute_linear_spec: bool = False
precompute_num_workers: int = 0
use_noise_augment: bool = False use_noise_augment: bool = False
add_blank: bool = False start_by_longest: bool = False
# dataset # dataset
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer # optimizer
@ -218,3 +259,6 @@ class BaseTTSConfig(BaseTrainingConfig):
lr_scheduler_params: dict = field(default_factory=lambda: {}) lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing # testing
test_sentences: List[str] = field(default_factory=lambda: []) test_sentences: List[str] = field(default_factory=lambda: [])
# evaluation
eval_split_max_size: int = None
eval_split_size: float = 0.01
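The new eval_split_size / eval_split_max_size fields are read as a proportion when between 0 and 1 and as an absolute count when greater than 1, optionally capped by the max size. A small sketch of that arithmetic; the sample counts are made up and the helper name is hypothetical.

def resolve_eval_split(num_samples: int, eval_split_size: float = 0.01, eval_split_max_size: int = None) -> int:
    # Number of evaluation samples implied by the config fields (illustrative).
    if eval_split_size > 1:
        return int(eval_split_size)
    size = int(num_samples * eval_split_size)
    if eval_split_max_size:
        size = min(eval_split_max_size, size)
    return size

print(resolve_eval_split(13100))                          # 131 -> 1% of the dataset
print(resolve_eval_split(13100, eval_split_max_size=50))  # 50  -> capped by the max size
print(resolve_eval_split(13100, eval_split_size=256))     # 256 -> absolute sample count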

View File

@ -89,8 +89,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
binary_loss_alpha (float): binary_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int): binary_loss_warmup_epochs (int):
Start binary alignment loss after this many steps. Defaults to 20000. Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int): min_seq_len (int):
Minimum input sequence length to be used at training. Minimum input sequence length to be used at training.
@ -150,7 +150,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
spec_loss_alpha: float = 1.0 spec_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 0.3 binary_align_loss_alpha: float = 0.3
binary_align_loss_start_step: int = 50000 binary_loss_warmup_epochs: int = 150
# overrides # overrides
min_seq_len: int = 13 min_seq_len: int = 13

View File

@ -83,6 +83,8 @@ class TacotronConfig(BaseTTSConfig):
ddc_r (int): ddc_r (int):
reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this
as a multiple of the `r` value. Defaults to 6. as a multiple of the `r` value. Defaults to 6.
speakers_file (str):
Path to the speaker mapping file for the Speaker Manager. Defaults to None.
use_speaker_embedding (bool): use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False. in the multi-speaker mode. Defaults to False.
@ -176,6 +178,7 @@ class TacotronConfig(BaseTTSConfig):
ddc_r: int = 6 ddc_r: int = 6
# multi-speaker settings # multi-speaker settings
speakers_file: str = None
use_speaker_embedding: bool = False use_speaker_embedding: bool = False
speaker_embedding_dim: int = 512 speaker_embedding_dim: int = 512
use_d_vector_file: bool = False use_d_vector_file: bool = False

View File

@ -17,7 +17,7 @@ class VitsConfig(BaseTTSConfig):
Model architecture arguments. Defaults to `VitsArgs()`. Model architecture arguments. Defaults to `VitsArgs()`.
grad_clip (List): grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`. Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
lr_gen (float): lr_gen (float):
Initial learning rate for the generator. Defaults to 0.0002. Initial learning rate for the generator. Defaults to 0.0002.
@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec (bool): compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
sort_by_audio_len (bool):
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
min_seq_len (int):
Minimum sequnce length to be considered for training. Defaults to `0`.
max_seq_len (int):
Maximum sequnce length to be considered for training. Defaults to `500000`.
r (int): r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
@ -130,9 +121,6 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec: bool = True compute_linear_spec: bool = True
# overrides # overrides
sort_by_audio_len: bool = True
min_seq_len: int = 0
max_seq_len: int = 500000
r: int = 1 # DO NOT CHANGE r: int = 1 # DO NOT CHANGE
add_blank: bool = True add_blank: bool = True

View File

@ -9,25 +9,48 @@ from TTS.tts.datasets.dataset import *
from TTS.tts.datasets.formatters import * from TTS.tts.datasets.formatters import *
def split_dataset(items): def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
Args:
items (List[List]):
A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
eval_split_max_size (int):
Maximum number of samples to be used for evaluation in a proportional split. Defaults to None (disabled).
eval_split_size (float):
If between 0.0 and 1.0, represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
""" """
speakers = [item[-1] for item in items] speakers = [item["speaker_name"] for item in items]
is_multi_speaker = len(set(speakers)) > 1 is_multi_speaker = len(set(speakers)) > 1
eval_split_size = min(500, int(len(items) * 0.01)) if eval_split_size > 1:
assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples." eval_split_size = int(eval_split_size)
else:
if eval_split_max_size:
eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size))
else:
eval_split_size = int(len(items) * eval_split_size)
assert (
eval_split_size > 0
), " [!] You do not have enough samples for the evaluation set. You can work around this by setting the 'eval_split_size' parameter to a minimum of {}".format(
1 / len(items)
)
np.random.seed(0) np.random.seed(0)
np.random.shuffle(items) np.random.shuffle(items)
if is_multi_speaker: if is_multi_speaker:
items_eval = [] items_eval = []
speakers = [item[-1] for item in items] speakers = [item["speaker_name"] for item in items]
speaker_counter = Counter(speakers) speaker_counter = Counter(speakers)
while len(items_eval) < eval_split_size: while len(items_eval) < eval_split_size:
item_idx = np.random.randint(0, len(items)) item_idx = np.random.randint(0, len(items))
speaker_to_be_removed = items[item_idx][-1] speaker_to_be_removed = items[item_idx]["speaker_name"]
if speaker_counter[speaker_to_be_removed] > 1: if speaker_counter[speaker_to_be_removed] > 1:
items_eval.append(items[item_idx]) items_eval.append(items[item_idx])
speaker_counter[speaker_to_be_removed] -= 1 speaker_counter[speaker_to_be_removed] -= 1
@ -37,7 +60,11 @@ def split_dataset(items):
def load_tts_samples( def load_tts_samples(
datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None datasets: Union[List[Dict], Dict],
eval_split=True,
formatter: Callable = None,
eval_split_max_size=None,
eval_split_size=0.01,
) -> Tuple[List[List], List[List]]: ) -> Tuple[List[List], List[List]]:
"""Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
@ -52,9 +79,16 @@ def load_tts_samples(
formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It
must take the root_path and the meta_file name and return a list of samples in the format of must take the root_path and the meta_file name and return a list of samples in the format of
`[[audio_path, text, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as `[[text, audio_path, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as
example. Defaults to None. example. Defaults to None.
eval_split_max_size (int):
Maximum number of samples to be used for evaluation in a proportional split. Defaults to None (disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
Returns: Returns:
Tuple[List[List], List[List]: training and evaluation splits of the dataset. Tuple[List[List], List[List]: training and evaluation splits of the dataset.
""" """
@ -75,28 +109,28 @@ def load_tts_samples(
formatter = _get_formatter_by_name(name) formatter = _get_formatter_by_name(name)
# load train set # load train set
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
meta_data_train = [[*item, language] for item in meta_data_train] meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set # load evaluation split if set
if eval_split: if eval_split:
if meta_file_val: if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
meta_data_eval = [[*item, language] for item in meta_data_eval] meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
else: else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train) meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
meta_data_eval_all += meta_data_eval meta_data_eval_all += meta_data_eval
meta_data_train_all += meta_data_train meta_data_train_all += meta_data_train
# load attention masks for the duration predictor training # load attention masks for the duration predictor training
if dataset.meta_file_attn_mask: if dataset.meta_file_attn_mask:
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
for idx, ins in enumerate(meta_data_train_all): for idx, ins in enumerate(meta_data_train_all):
attn_file = meta_data[ins[1]].strip() attn_file = meta_data[ins["audio_file"]].strip()
meta_data_train_all[idx].append(attn_file) meta_data_train_all[idx].update({"alignment_file": attn_file})
if meta_data_eval_all: if meta_data_eval_all:
for idx, ins in enumerate(meta_data_eval_all): for idx, ins in enumerate(meta_data_eval_all):
attn_file = meta_data[ins[1]].strip() attn_file = meta_data[ins["audio_file"]].strip()
meta_data_eval_all[idx].append(attn_file) meta_data_eval_all[idx].update({"alignment_file": attn_file})
# set none for the next iter # set none for the next iter
formatter = None formatter = None
return meta_data_train_all, meta_data_eval_all return meta_data_train_all, meta_data_eval_all
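After this change the loaders exchange samples as dicts keyed by name rather than positional lists. A sketch of the schema the code above reads and writes; the text and paths are placeholders.

sample = {
    "text": "Hello world.",
    "audio_file": "/data/wavs/sample_0001.wav",
    "speaker_name": "speaker_a",
    "language": "en",
}

# load_tts_samples() attaches extra keys the same way:
sample.update({"alignment_file": "/data/alignments/sample_0001.npy"})

# Downstream code indexes by key instead of by position:
print(sample["speaker_name"], sample["audio_file"])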

View File

@ -1,8 +1,7 @@
import collections import collections
import os import os
import random import random
from multiprocessing import Pool from typing import Dict, List, Union
from typing import Dict, List
import numpy as np import numpy as np
import torch import torch
@ -10,87 +9,99 @@ import tqdm
from torch.utils.data import Dataset from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy("file_system")
def _parse_sample(item):
language_name = None
attn_file = None
if len(item) == 5:
text, wav_file, speaker_name, language_name, attn_file = item
elif len(item) == 4:
text, wav_file, speaker_name, language_name = item
elif len(item) == 3:
text, wav_file, speaker_name = item
else:
raise ValueError(" [!] Dataset cannot parse the sample.")
return text, wav_file, speaker_name, language_name, attn_file
def noise_augment_audio(wav):
return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
class TTSDataset(Dataset): class TTSDataset(Dataset):
def __init__( def __init__(
self, self,
outputs_per_step: int, outputs_per_step: int = 1,
text_cleaner: list, compute_linear_spec: bool = False,
compute_linear_spec: bool, ap: AudioProcessor = None,
ap: AudioProcessor, samples: List[Dict] = None,
meta_data: List[List], tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False, compute_f0: bool = False,
f0_cache_path: str = None, f0_cache_path: str = None,
characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False,
return_wav: bool = False, return_wav: bool = False,
batch_group_size: int = 0, batch_group_size: int = 0,
min_seq_len: int = 0, min_text_len: int = 0,
max_seq_len: int = float("inf"), max_text_len: int = float("inf"),
use_phonemes: bool = False, min_audio_len: int = 0,
max_audio_len: int = float("inf"),
phoneme_cache_path: str = None, phoneme_cache_path: str = None,
phoneme_language: str = "en-us", precompute_num_workers: int = 0,
enable_eos_bos: bool = False,
speaker_id_mapping: Dict = None, speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None, d_vector_mapping: Dict = None,
language_id_mapping: Dict = None, language_id_mapping: Dict = None,
use_noise_augment: bool = False, use_noise_augment: bool = False,
start_by_longest: bool = False,
verbose: bool = False, verbose: bool = False,
): ):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
If you need something different, you can inherit and override. If you need something different, you can subclass and override.
Args: Args:
outputs_per_step (int): Number of time frames predicted per step. outputs_per_step (int): Number of time frames predicted per step.
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
compute_linear_spec (bool): compute linear spectrogram if True. compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.tts.utils.AudioProcessor): Audio processor object. ap (TTS.tts.utils.AudioProcessor): Audio processor object.
meta_data (list): List of dataset instances. samples (list): List of dataset samples.
tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
use the given. Defaults to None.
compute_f0 (bool): compute f0 if True. Defaults to False. compute_f0 (bool): compute f0 if True. Defaults to False.
f0_cache_path (str): Path to store f0 cache. Defaults to None. f0_cache_path (str): Path to store f0 cache. Defaults to None.
characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
set of symbols need to pass it here. Defaults to `None`.
add_blank (bool): Add a special `blank` character after every other character. It helps some
models achieve better results. Defaults to false.
return_wav (bool): Return the waveform of the sample. Defaults to False. return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting batch_group_size (int): Range of batch randomization after sorting
sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
batch. Set 0 to disable. Defaults to 0. batch. Set 0 to disable. Defaults to 0.
min_seq_len (int): Minimum input sequence length to be processed min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a Defaults to 0.
minimum input length due to its architecture. Defaults to 0.
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this. max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
It helps for controlling the VRAM usage against long input sequences. Especially models with Defaults to float("inf").
RNN layers are sensitive to input length. Defaults to `Inf`.
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false. min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
Defaults to 0.
max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
this value if you encounter an OOM error in training. Defaults to float("inf").
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
separate file. Defaults to None. separate file. Defaults to None.
phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`. precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.
enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults
to False.
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
embedding layer. Defaults to None. embedding layer. Defaults to None.
@ -99,285 +110,254 @@ class TTSDataset(Dataset):
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
verbose (bool): Print diagnostic information. Defaults to false. verbose (bool): Print diagnostic information. Defaults to false.
""" """
super().__init__() super().__init__()
self.batch_group_size = batch_group_size self.batch_group_size = batch_group_size
self.items = meta_data self._samples = samples
self.outputs_per_step = outputs_per_step self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav self.return_wav = return_wav
self.compute_f0 = compute_f0 self.compute_f0 = compute_f0
self.f0_cache_path = f0_cache_path self.f0_cache_path = f0_cache_path
self.min_seq_len = min_seq_len self.min_audio_len = min_audio_len
self.max_seq_len = max_seq_len self.max_audio_len = max_audio_len
self.min_text_len = min_text_len
self.max_text_len = max_text_len
self.ap = ap self.ap = ap
self.characters = characters
self.custom_symbols = custom_symbols
self.add_blank = add_blank
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language
self.enable_eos_bos = enable_eos_bos
self.speaker_id_mapping = speaker_id_mapping self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping self.d_vector_mapping = d_vector_mapping
self.language_id_mapping = language_id_mapping self.language_id_mapping = language_id_mapping
self.use_noise_augment = use_noise_augment self.use_noise_augment = use_noise_augment
self.start_by_longest = start_by_longest
self.verbose = verbose self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1 self.rescue_item_idx = 1
self.pitch_computed = False self.pitch_computed = False
self.tokenizer = tokenizer
if self.tokenizer.use_phonemes:
self.phoneme_dataset = PhonemeDataset(
self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
)
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
if compute_f0: if compute_f0:
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose) self.f0_dataset = F0Dataset(
self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
)
if self.verbose: if self.verbose:
print("\n > DataLoader initialization") self.print_logs()
print(" | > Use phonemes: {}".format(self.use_phonemes))
if use_phonemes: @property
print(" | > phoneme language: {}".format(phoneme_language)) def lengths(self):
print(" | > Number of instances : {}".format(len(self.items))) lens = []
for item in self.samples:
_, wav_file, *_ = _parse_sample(item)
audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio
lens.append(audio_len)
return lens
@property
def samples(self):
return self._samples
@samples.setter
def samples(self, new_samples):
self._samples = new_samples
if hasattr(self, "f0_dataset"):
self.f0_dataset.samples = new_samples
if hasattr(self, "phoneme_dataset"):
self.phoneme_dataset.samples = new_samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.load_data(idx)
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> DataLoader initialization")
print(f"{indent}| > Tokenizer:")
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
def load_wav(self, filename): def load_wav(self, filename):
audio = self.ap.load_wav(filename) waveform = self.ap.load_wav(filename)
return audio assert waveform.size > 0
return waveform
def get_phonemes(self, idx, text):
out_dict = self.phoneme_dataset[idx]
assert text == out_dict["text"], f"{text} != {out_dict['text']}"
assert len(out_dict["token_ids"]) > 0
return out_dict
def get_f0(self, idx):
out_dict = self.f0_dataset[idx]
item = self.samples[idx]
assert item["audio_file"] == out_dict["audio_file"]
return out_dict
@staticmethod @staticmethod
def load_np(filename): def get_attn_mask(attn_file):
data = np.load(filename).astype("float32") return np.load(attn_file)
return data
@staticmethod def get_token_ids(self, idx, text):
def _generate_and_cache_phoneme_sequence( if self.tokenizer.use_phonemes:
text, cache_path, cleaners, language, custom_symbols, characters, add_blank token_ids = self.get_phonemes(idx, text)["token_ids"]
): else:
"""generate a phoneme sequence from text. token_ids = self.tokenizer.text_to_ids(text)
since the usage is for subsequent caching, we never add bos and return np.array(token_ids, dtype=np.int32)
eos chars here. Instead we add those dynamically later; based on the
config option."""
phonemes = phoneme_to_sequence(
text,
[cleaners],
language=language,
enable_eos_bos=False,
custom_symbols=custom_symbols,
tp=characters,
add_blank=add_blank,
)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
return phonemes
@staticmethod
def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
# different names for normal phonemes and with blank chars.
file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
try:
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=characters)
phonemes = np.asarray(phonemes, dtype=np.int32)
return phonemes
def load_data(self, idx): def load_data(self, idx):
item = self.items[idx] item = self.samples[idx]
if len(item) == 5: raw_text = item["text"]
text, wav_file, speaker_name, language_name, attn_file = item
else:
text, wav_file, speaker_name, language_name = item
attn = None
raw_text = text
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
# apply noise for augmentation # apply noise for augmentation
if self.use_noise_augment: if self.use_noise_augment:
wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) wav = noise_augment_audio(wav)
if not self.input_seq_computed: # get token ids
if self.use_phonemes: token_ids = self.get_token_ids(idx, item["text"])
text = self._load_or_generate_phoneme_sequence(
wav_file,
text,
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
language_name if language_name else self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
)
else:
text = np.asarray(
text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
assert text.size > 0, self.items[idx][1] # get pre-computed attention maps
assert wav.size > 0, self.items[idx][1] attn = None
if "alignment_file" in item:
attn = self.get_attn_mask(item["alignment_file"])
if "attn_file" in locals(): # after phonemization the text length may change
attn = np.load(attn_file) # this is a shareful 🤭 hack to prevent longer phonemes
# TODO: find a better fix
if len(text) > self.max_seq_len: if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
# return a different sample if the phonemized self.rescue_item_idx += 1
# text is longer than the threshold
# TODO: find a better fix
return self.load_data(self.rescue_item_idx) return self.load_data(self.rescue_item_idx)
pitch = None # get f0 values
f0 = None
if self.compute_f0: if self.compute_f0:
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) f0 = self.get_f0(idx)["f0"]
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
sample = { sample = {
"raw_text": raw_text, "raw_text": raw_text,
"text": text, "token_ids": token_ids,
"wav": wav, "wav": wav,
"pitch": pitch, "pitch": f0,
"attn": attn, "attn": attn,
"item_idx": self.items[idx][1], "item_idx": item["audio_file"],
"speaker_name": speaker_name, "speaker_name": item["speaker_name"],
"language_name": language_name, "language_name": item["language"],
"wav_file_name": os.path.basename(wav_file), "wav_file_name": os.path.basename(item["audio_file"]),
} }
return sample return sample
@staticmethod @staticmethod
def _phoneme_worker(args): def _compute_lengths(samples):
item = args[0] new_samples = []
func_args = args[1] for item in samples:
text, wav_file, *_ = item audio_length = os.path.getsize(item["audio_file"]) / 16 * 8 # assuming 16bit audio
func_args[3] = ( text_length = len(item["text"])
item[3] if item[3] else func_args[3] item["audio_length"] = audio_length
) # override phoneme language if specified by the dataset formatter item["text_length"] = text_length
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args) new_samples += [item]
return phonemes return new_samples
def compute_input_seq(self, num_workers=0): @staticmethod
"""Compute the input sequences with multi-processing. def filter_by_length(lengths: List[int], min_len: int, max_len: int):
Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" idxs = np.argsort(lengths) # ascending order
if not self.use_phonemes: ignore_idx = []
if self.verbose: keep_idx = []
print(" | > Computing input sequences ...") for idx in idxs:
for idx, item in enumerate(tqdm.tqdm(self.items)): length = lengths[idx]
text, *_ = item if length < min_len or length > max_len:
sequence = np.asarray( ignore_idx.append(idx)
text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
self.items[idx][0] = sequence
else:
func_args = [
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
]
if self.verbose:
print(" | > Computing phonemes ...")
if num_workers == 0:
for idx, item in enumerate(tqdm.tqdm(self.items)):
phonemes = self._phoneme_worker([item, func_args])
self.items[idx][0] = phonemes
else: else:
with Pool(num_workers) as p: keep_idx.append(idx)
phonemes = list( return ignore_idx, keep_idx
tqdm.tqdm(
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
total=len(self.items),
)
)
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
def sort_and_filter_items(self, by_audio_len=False): @staticmethod
def sort_by_length(samples: List[List]):
audio_lengths = [s["audio_length"] for s in samples]
idxs = np.argsort(audio_lengths) # ascending order
return idxs
@staticmethod
def create_buckets(samples, batch_group_size: int):
assert batch_group_size > 0
for i in range(len(samples) // batch_group_size):
offset = i * batch_group_size
end_offset = offset + batch_group_size
temp_items = samples[offset:end_offset]
random.shuffle(temp_items)
samples[offset:end_offset] = temp_items
return samples
@staticmethod
def _select_samples_by_idx(idxs, samples):
samples_new = []
for idx in idxs:
samples_new.append(samples[idx])
return samples_new
def preprocess_samples(self):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range. range.
Args:
by_audio_len (bool): if True, sort by audio length else by text length.
""" """
# compute the target sequence length samples = self._compute_lengths(self.samples)
if by_audio_len:
lengths = [] # sort items based on the sequence length in ascending order
for item in self.items: text_lengths = [i["text_length"] for i in samples]
lengths.append(os.path.getsize(item[1]) / 16 * 8) # assuming 16bit audio audio_lengths = [i["audio_length"] for i in samples]
lengths = np.array(lengths) text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
else: audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
lengths = np.array([len(ins[0]) for ins in self.items]) keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
samples = self._select_samples_by_idx(keep_idx, samples)
sorted_idxs = self.sort_by_length(samples)
if self.start_by_longest:
longest_idxs = sorted_idxs[-1]
sorted_idxs[-1] = sorted_idxs[0]
sorted_idxs[0] = longest_idxs
samples = self._select_samples_by_idx(sorted_idxs, samples)
if len(samples) == 0:
raise RuntimeError(" [!] No samples left")
idxs = np.argsort(lengths)
new_items = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len or length > self.max_seq_len:
ignored.append(idx)
else:
new_items.append(self.items[idx])
# shuffle batch groups # shuffle batch groups
# create batches with similar length items
# the larger the `batch_group_size`, the higher the length variety in a batch.
if self.batch_group_size > 0: if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size): samples = self.create_buckets(samples, self.batch_group_size)
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size # update items to the new sorted items
temp_items = new_items[offset:end_offset] audio_lengths = [s["audio_length"] for s in samples]
random.shuffle(temp_items) text_lengths = [s["text_length"] for s in samples]
new_items[offset:end_offset] = temp_items self.samples = samples
self.items = new_items
if self.verbose: if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths))) print(" | > Preprocessing samples")
print(" | > Min length sequence: {}".format(np.min(lengths))) print(" | > Max text length: {}".format(np.max(text_lengths)))
print(" | > Avg length sequence: {}".format(np.mean(lengths))) print(" | > Min text length: {}".format(np.min(text_lengths)))
print( print(" | > Avg text length: {}".format(np.mean(text_lengths)))
" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format( print(" | ")
self.max_seq_len, self.min_seq_len, len(ignored) print(" | > Max audio length: {}".format(np.max(audio_lengths)))
) print(" | > Min audio length: {}".format(np.min(audio_lengths)))
) print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
print(f" | > Num. instances discarded by length limits: {len(ignore_idx)}")
print(" | > Batch group size: {}.".format(self.batch_group_size)) print(" | > Batch group size: {}.".format(self.batch_group_size))
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
return self.load_data(idx)
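The new preprocessing filters samples by text and audio length and then sorts by audio length. A standalone sketch of the filter step on toy lengths; the numbers are arbitrary.

import numpy as np

audio_lengths = [12000, 480000, 800, 96000]  # toy per-sample lengths (in samples)
min_len, max_len = 1000, 400000

idxs = np.argsort(audio_lengths)  # ascending order, as in sort_by_length()
ignore_idx, keep_idx = [], []
for idx in idxs:
    length = audio_lengths[idx]
    if length < min_len or length > max_len:
        ignore_idx.append(int(idx))
    else:
        keep_idx.append(int(idx))

print("kept (sorted by length):", keep_idx)  # [0, 3]
print("ignored:", ignore_idx)                # [2, 1]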
@staticmethod @staticmethod
def _sort_batch(batch, text_lengths): def _sort_batch(batch, text_lengths):
"""Sort the batch by the input text length for RNN efficiency. """Sort the batch by the input text length for RNN efficiency.
@ -402,10 +382,10 @@ class TTSDataset(Dataset):
# Puts each data field into a tensor with outer dimension batch size # Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.abc.Mapping): if isinstance(batch[0], collections.abc.Mapping):
text_lengths = np.array([len(d["text"]) for d in batch]) token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])
# sort items with text input length for RNN efficiency # sort items with text input length for RNN efficiency
batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths) batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths)
# convert list of dicts to dict of lists # convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]} batch = {k: [dic[k] for dic in batch] for k in batch[0]}
@ -447,7 +427,7 @@ class TTSDataset(Dataset):
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
# PAD sequences with longest instance in the batch
token_ids = prepare_data(batch["token_ids"]).astype(np.int32)
# PAD features with longest instance
mel = prepare_tensor(mel, self.outputs_per_step)
@ -456,12 +436,13 @@ class TTSDataset(Dataset):
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
token_ids_lengths = torch.LongTensor(token_ids_lengths)
token_ids = torch.LongTensor(token_ids)
mel = torch.FloatTensor(mel).contiguous()
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
# speaker vectors
if d_vectors is not None:
d_vectors = torch.FloatTensor(d_vectors)
@ -472,14 +453,13 @@ class TTSDataset(Dataset):
language_ids = torch.LongTensor(language_ids)
# compute linear spectrogram
linear = None
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
linear = torch.FloatTensor(linear).contiguous()
# format waveforms
wav_padded = None
@ -495,8 +475,7 @@ class TTSDataset(Dataset):
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
# format F0
if self.compute_f0:
pitch = prepare_data(batch["pitch"])
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
@ -504,23 +483,22 @@ class TTSDataset(Dataset):
else:
pitch = None
# format attention masks
attns = None
if batch["attn"][0] is not None:
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
for idx, attn in enumerate(attns):
pad2 = mel.shape[1] - attn.shape[1]
pad1 = token_ids.shape[1] - attn.shape[0]
assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
attn = np.pad(attn, [[0, pad1], [0, pad2]])
attns[idx] = attn
attns = prepare_tensor(attns, self.outputs_per_step)
attns = torch.FloatTensor(attns).unsqueeze(1)
return {
"token_id": token_ids,
"token_id_lengths": token_ids_lengths,
"speaker_names": batch["speaker_name"],
"linear": linear,
"mel": mel,
@ -546,22 +524,185 @@ class TTSDataset(Dataset):
)
class PhonemeDataset(Dataset):
"""Phoneme Dataset for converting input text to phonemes and then token IDs
At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
loading latency. If `cache_path` is already present, it skips the pre-computation.
Args:
samples (Union[List[List], List[Dict]]):
List of samples. Each sample is a list or a dict.
tokenizer (TTSTokenizer):
Tokenizer to convert input text to phonemes.
cache_path (str):
Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
precompute_num_workers (int):
Number of workers used for pre-computing the phonemes. Defaults to 0.
"""
def __init__(
self,
samples: Union[List[Dict], List[List]],
tokenizer: "TTSTokenizer",
cache_path: str,
precompute_num_workers=0,
):
self.samples = samples
self.tokenizer = tokenizer
self.cache_path = cache_path
if cache_path is not None and not os.path.exists(cache_path):
os.makedirs(cache_path)
self.precompute(precompute_num_workers)
def __getitem__(self, index):
item = self.samples[index]
ids = self.compute_or_load(item["audio_file"], item["text"])
ph_hat = self.tokenizer.ids_to_text(ids)
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
def __len__(self):
return len(self.samples)
def compute_or_load(self, wav_file, text):
"""Compute phonemes for the given text.
If the phonemes are already cached, load them from cache.
"""
file_name = os.path.splitext(os.path.basename(wav_file))[0]
file_ext = "_phoneme.npy"
cache_path = os.path.join(self.cache_path, file_name + file_ext)
try:
ids = np.load(cache_path)
except FileNotFoundError:
ids = self.tokenizer.text_to_ids(text)
np.save(cache_path, ids)
return ids
def get_pad_id(self):
"""Get pad token ID for sequence padding"""
return self.tokenizer.pad_id
def precompute(self, num_workers=1):
"""Precompute phonemes for all samples.
We use pytorch dataloader because we are lazy.
"""
print("[*] Pre-computing phonemes...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
dataloder = torch.utils.data.DataLoader(
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
)
for _ in dataloder:
pbar.update(batch_size)
def collate_fn(self, batch):
ids = [item["token_ids"] for item in batch]
ids_lens = [item["token_ids_len"] for item in batch]
texts = [item["text"] for item in batch]
texts_hat = [item["ph_hat"] for item in batch]
ids_lens_max = max(ids_lens)
ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
for i, ids_len in enumerate(ids_lens):
ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> PhonemeDataset ")
print(f"{indent}| > Tokenizer:")
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
class F0Dataset:
"""F0 Dataset for computing F0 from wav files in CPU
Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
also computes the mean and std of F0 values if `normalize_f0` is True.
Args:
samples (Union[List[List], List[Dict]]):
List of samples. Each sample is a list or a dict.
ap (AudioProcessor):
AudioProcessor to compute F0 from wav files.
cache_path (str):
Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
Defaults to None.
precompute_num_workers (int):
Number of workers used for pre-computing the F0 values. Defaults to 0.
normalize_f0 (bool):
Whether to normalize F0 values by mean and std. Defaults to True.
"""
def __init__(
self,
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_f0=True,
):
self.samples = samples
self.ap = ap
self.verbose = verbose
self.cache_path = cache_path
self.normalize_f0 = normalize_f0
self.pad_id = 0.0
self.mean = None
self.std = None
if cache_path is not None and not os.path.exists(cache_path):
os.makedirs(cache_path)
self.precompute(precompute_num_workers)
if normalize_f0:
self.load_stats(cache_path)
def __getitem__(self, idx):
item = self.samples[idx]
f0 = self.compute_or_load(item["audio_file"])
if self.normalize_f0:
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
f0 = self.normalize(f0)
return {"audio_file": item["audio_file"], "f0": f0}
def __len__(self):
return len(self.samples)
def precompute(self, num_workers=0):
print("[*] Pre-computing F0s...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
# we do not normalize at preprocessing
normalize_f0 = self.normalize_f0
self.normalize_f0 = False
dataloder = torch.utils.data.DataLoader(
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
)
computed_data = []
for batch in dataloder:
f0 = batch["f0"]
computed_data.append(f for f in f0)
pbar.update(batch_size)
self.normalize_f0 = normalize_f0
if self.normalize_f0:
computed_data = [tensor for batch in computed_data for tensor in batch] # flatten
pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
def get_pad_id(self):
return self.pad_id
@staticmethod
def create_pitch_file_path(wav_file, cache_path):
@ -583,70 +724,49 @@ class PitchExtractor:
mean, std = np.mean(nonzeros), np.std(nonzeros)
return mean, std
def load_stats(self, cache_path):
stats_path = os.path.join(cache_path, "pitch_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
def normalize(self, pitch):
zero_idxs = np.where(pitch == 0.0)[0] zero_idxs = np.where(pitch == 0.0)[0]
pitch = pitch - self.mean pitch = pitch - self.mean
pitch = pitch / self.std pitch = pitch / self.std
pitch[zero_idxs] = 0.0 pitch[zero_idxs] = 0.0
return pitch return pitch
def denormalize(self, pitch):
zero_idxs = np.where(pitch == 0.0)[0]
pitch *= self.std
pitch += self.mean
pitch[zero_idxs] = 0.0
return pitch
def compute_or_load(self, wav_file):
"""
compute pitch and return a numpy array of pitch values
"""
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)
def collate_fn(self, batch):
audio_file = [item["audio_file"] for item in batch]
f0s = [item["f0"] for item in batch]
f0_lens = [len(item["f0"]) for item in batch]
f0_lens_max = max(f0_lens)
f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
for i, f0_len in enumerate(f0_lens):
f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")
def compute_pitch(self, ap, cache_path, num_workers=0):
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not os.path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
if self.verbose:
print(" | > Computing pitch features ...")
if num_workers == 0:
pitch_vecs = []
for _, item in enumerate(tqdm.tqdm(self.items)):
pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
else:
with Pool(num_workers) as p:
pitch_vecs = list(
tqdm.tqdm(
p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
total=len(self.items),
)
)
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
def load_pitch_stats(self, cache_path):
stats_path = os.path.join(cache_path, "pitch_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
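# Standalone numpy sketch of the F0 normalization above: voiced frames are standardized with the
# cached mean/std while unvoiced frames (f0 == 0) stay at zero, as in `normalize`/`denormalize`.
import numpy as np


def normalize_f0_sketch(f0, mean, std):
    f0 = f0.copy()
    voiced = f0 != 0.0
    f0[voiced] = (f0[voiced] - mean) / std
    return f0


print(normalize_f0_sketch(np.array([0.0, 110.0, 220.0, 0.0], dtype=np.float32), mean=165.0, std=55.0))
# -> [ 0. -1.  1.  0.]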

View File

@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
wav_file = cols[1].strip()
text = cols[0].strip()
wav_file = os.path.join(root_path, "wavs", wav_file)
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
text = cols[1].strip()
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
wav_file = os.path.join(root_path, folder_name, wav_file)
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
if os.path.isfile(wav_file):
text = cols[1].strip()
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
else:
# M-AI-Labs have some missing samples, so just print the warning
print("> File %s does not exist!" % (wav_file))
@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -129,11 +129,15 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, "r", encoding="utf-8") as ttf:
speaker_id = 0
for idx, line in enumerate(ttf):
# 2 samples per speaker to avoid eval split issues
if idx % 2 == 0:
speaker_id += 1
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
return items
@ -150,7 +154,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
if not os.path.exists(wav_file):
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -165,7 +169,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -179,7 +183,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -193,7 +197,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
utt_id = line.split()[1]
text = line[line.find('"') + 1 : line.rfind('"') - 1]
wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -213,7 +217,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
if speaker_name in ignored_speakers:
continue
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
return items
@ -240,7 +244,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
for item in items:
assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
return items
@ -259,7 +263,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
skipped_files.append(wav_file)
continue
text = cols[1].strip()
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
return items
@ -281,12 +285,32 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
return items
def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None):
"""VCTK dataset v0.92.
URL:
https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip
This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```.
It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset.
mic1:
Audio recorded using an omni-directional microphone (DPA 4035).
Contains very low frequency noises.
This is the same audio released in previous versions of VCTK:
https://doi.org/10.7488/ds/1994
mic2:
Audio recorded using a small diaphragm condenser microphone with
very wide bandwidth (Sennheiser MKH 800).
Two speakers, p280 and p315 had technical issues of the audio
recordings using MKH 800.
"""
file_ext = "flac"
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
@ -298,26 +322,33 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
# p280 has no mic2 recordings
if speaker_id == "p280":
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}")
else:
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
if os.path.exists(wav_file):
items.append([text, wav_file, "VCTK_" + speaker_id])
else:
print(f" [!] wav files don't exist - {wav_file}")
return items
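# Illustrative sketch of the VCTK 0.92 path construction above: file names carry a `_mic1`/`_mic2`
# suffix and a `.flac` extension, and speaker p280 falls back to mic1 (the paths below are made up).
import os


def vctk_wav_path_sketch(root_path, speaker_id, file_id, mic="mic1", wavs_path="wav48_silence_trimmed"):
    chosen_mic = "mic1" if speaker_id == "p280" else mic  # p280 has no mic2 recordings
    return os.path.join(root_path, wavs_path, speaker_id, f"{file_id}_{chosen_mic}.flac")


print(vctk_wav_path_sketch("/data/VCTK-Corpus-0.92", "p225", "p225_001", mic="mic2"))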
def vctk_old(root_path, meta_files=None, wavs_path="wav48"):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
test_speakers = meta_files
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
if isinstance(test_speakers, list):  # if is list ignore this speakers ids
if speaker_id in test_speakers:
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([text, wav_file, "VCTK_old_" + speaker_id])
return items
@ -334,7 +365,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
return items
@ -404,7 +435,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
for line in ttf:
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
return items
@ -418,5 +449,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2].replace(" ", "")
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
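# Illustrative shape of a single formatter sample after this change: every formatter now returns a
# dict with named fields instead of a positional `[text, wav_file, speaker_name]` list (values made up).
sample = {
    "text": "Hello world.",
    "audio_file": "/data/LJSpeech-1.1/wavs/LJ001-0001.wav",
    "speaker_name": "ljspeech",
}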

View File

@ -113,7 +113,7 @@ class ActNorm(nn.Module):
denom = torch.sum(x_mask, [0, 2]) denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m ** 2) v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)

View File

@ -65,7 +65,7 @@ class WN(torch.nn.Module):
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
# intermediate layers # intermediate layers
for i in range(num_layers): for i in range(num_layers):
dilation = dilation_rate ** i dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2) padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d( in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding

View File

@ -101,7 +101,7 @@ class Encoder(nn.Module):
self.encoder_type = encoder_type self.encoder_type = encoder_type
# embedding layer # embedding layer
self.emb = nn.Embedding(num_chars, hidden_channels) self.emb = nn.Embedding(num_chars, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
# init encoder module # init encoder module
if encoder_type.lower() == "rel_pos_transformer": if encoder_type.lower() == "rel_pos_transformer":
if use_prenet: if use_prenet:

View File

@ -88,7 +88,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
# relative positional encoding layers # relative positional encoding layers
if rel_attn_window_size is not None: if rel_attn_window_size is not None:
n_heads_rel = 1 if heads_share else num_heads n_heads_rel = 1 if heads_share else num_heads
rel_stddev = self.k_channels ** -0.5 rel_stddev = self.k_channels**-0.5
emb_rel_k = nn.Parameter( emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
) )
@ -235,7 +235,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
batch, heads, length, _ = x.size() batch, heads, length, _ = x.size()
# padd along column # padd along column
x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape # add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]

View File

@ -218,7 +218,7 @@ class GuidedAttentionLoss(torch.nn.Module):
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * sigma**2))
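# Standalone numpy sketch of the guided-attention penalty computed by `_make_ga_mask` above: the
# value is near 0 on the diagonal alignment and approaches 1 as the alignment drifts away from it.
import numpy as np


def make_ga_mask_sketch(ilen, olen, sigma=0.4):
    grid_x, grid_y = np.meshgrid(np.arange(olen), np.arange(ilen), indexing="ij")  # [olen, ilen]
    return 1.0 - np.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * sigma**2))


print(make_ga_mask_sketch(ilen=4, olen=6).round(2))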
@staticmethod
def _make_masks(ilens, olens):
@ -553,7 +553,6 @@ class VitsGeneratorLoss(nn.Module):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
@staticmethod
@ -588,13 +587,12 @@ class VitsGeneratorLoss(nn.Module):
@staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
def forward(
self,
mel_slice,
mel_slice_hat,
z_p,
logs_q,
m_p,
@ -610,8 +608,8 @@ class VitsGeneratorLoss(nn.Module):
):
"""
Shapes:
- mel_slice : :math:`[B, 1, T]`
- mel_slice_hat: :math:`[B, 1, T]`
- z_p: :math:`[B, C, T]`
- logs_q: :math:`[B, C, T]`
- m_p: :math:`[B, C, T]`
@ -624,23 +622,23 @@ class VitsGeneratorLoss(nn.Module):
loss = 0.0
return_dict = {}
z_mask = sequence_mask(z_len).float()
# compute losses
loss_kl = (
self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
* self.kl_loss_alpha
)
loss_feat = (
self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
)
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
if use_speaker_encoder_as_loss:
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
loss = loss + loss_se
return_dict["loss_spk_encoder"] = loss_se
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl
@ -665,20 +663,24 @@ class VitsDiscriminatorLoss(nn.Module):
dr = dr.float()
dg = dg.float()
real_loss = torch.mean((1 - dr) ** 2)
fake_loss = torch.mean(dg**2)
loss += real_loss + fake_loss
real_losses.append(real_loss.item())
fake_losses.append(fake_loss.item())
return loss, real_losses, fake_losses
def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, loss_disc_real, _ = self.discriminator_loss(
scores_real=scores_disc_real, scores_fake=scores_disc_fake
)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss
for i, ldr in enumerate(loss_disc_real):
return_dict[f"loss_disc_real_{i}"] = ldr
return return_dict
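# Standalone sketch of the least-squares (LSGAN) discriminator objective used above: real scores
# are pushed towards 1 and fake scores towards 0.
import torch


def lsgan_discriminator_loss_sketch(scores_real, scores_fake):
    real_loss = torch.mean((1 - scores_real) ** 2)
    fake_loss = torch.mean(scores_fake**2)
    return real_loss + fake_loss


print(lsgan_discriminator_loss_sketch(torch.tensor([0.9, 1.1]), torch.tensor([0.1, -0.2])))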
@ -740,6 +742,7 @@ class ForwardTTSLoss(nn.Module):
alignment_logprob=None,
alignment_hard=None,
alignment_soft=None,
binary_loss_weight=None,
):
loss = 0
return_dict = {}
@ -772,7 +775,12 @@ class ForwardTTSLoss(nn.Module):
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
if binary_loss_weight:
return_dict["loss_binary_alignment"] = (
self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
)
else:
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss"] = loss return_dict["loss"] = loss
return return_dict return return_dict
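# Tiny sketch of the optional weighting arithmetic introduced above: the binary alignment term is
# scaled by the fixed alpha and, when given, by an extra `binary_loss_weight` (e.g. a warm-up ramp).
def binary_alignment_term_sketch(binary_alignment_loss, alpha, binary_loss_weight=None):
    term = alpha * binary_alignment_loss
    return term * binary_loss_weight if binary_loss_weight else term


print(round(binary_alignment_term_sketch(0.5, alpha=0.1, binary_loss_weight=0.2), 3))  # 0.01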

View File

@ -141,7 +141,7 @@ class MultiHeadAttention(nn.Module):
# score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
scores = scores / (self.key_dim**0.5)
scores = F.softmax(scores, dim=3)
# out = score * V
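# Standalone sketch of the scaled dot-product attention step above:
# score = softmax(Q @ K^T / sqrt(d_k)), out = score @ V.
import torch
import torch.nn.functional as F


def scaled_dot_attention_sketch(queries, keys, values, key_dim):
    scores = torch.matmul(queries, keys.transpose(2, 3)) / (key_dim**0.5)  # [batch, heads, T_q, T_k]
    scores = F.softmax(scores, dim=3)
    return torch.matmul(scores, values)


q = k = v = torch.randn(2, 4, 5, 8)  # [batch, heads, time, key_dim]
print(scaled_dot_attention_sketch(q, k, v, key_dim=8).shape)  # torch.Size([2, 4, 5, 8])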

View File

@ -6,7 +6,6 @@ from .attentions import init_attn
from .common_layers import Linear, Prenet from .common_layers import Linear, Prenet
# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg # pylint: disable=unexpected-keyword-arg
class ConvBNBlock(nn.Module): class ConvBNBlock(nn.Module):

View File

@ -57,7 +57,7 @@ class TextEncoder(nn.Module):
self.emb = nn.Embedding(n_vocab, hidden_channels) self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
if language_emb_dim: if language_emb_dim:
hidden_channels += language_emb_dim hidden_channels += language_emb_dim
@ -83,6 +83,7 @@ class TextEncoder(nn.Module):
- x: :math:`[B, T]` - x: :math:`[B, T]`
- x_length: :math:`[B]` - x_length: :math:`[B]`
""" """
assert x.shape[0] == x_lengths.shape[0]
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
# concat the lang emb in embedding chars # concat the lang emb in embedding chars
@ -90,7 +91,7 @@ class TextEncoder(nn.Module):
x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
x = torch.transpose(x, 1, -1) # [b, h, t] x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t]
x = self.encoder(x * x_mask, x_mask) x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask stats = self.proj(x) * x_mask
@ -136,6 +137,9 @@ class ResidualCouplingBlock(nn.Module):
def forward(self, x, x_mask, g=None, reverse=False): def forward(self, x, x_mask, g=None, reverse=False):
""" """
Note:
Set `reverse` to True for inference.
Shapes: Shapes:
- x: :math:`[B, C, T]` - x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]` - x_mask: :math:`[B, 1, T]`
@ -209,6 +213,9 @@ class ResidualCouplingBlocks(nn.Module):
def forward(self, x, x_mask, g=None, reverse=False): def forward(self, x, x_mask, g=None, reverse=False):
""" """
Note:
Set `reverse` to True for inference.
Shapes: Shapes:
- x: :math:`[B, C, T]` - x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]` - x_mask: :math:`[B, 1, T]`

View File

@ -33,7 +33,7 @@ class DilatedDepthSeparableConv(nn.Module):
self.norms_1 = nn.ModuleList() self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList() self.norms_2 = nn.ModuleList()
for i in range(num_layers): for i in range(num_layers):
dilation = kernel_size ** i dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2 padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append( self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
@ -264,7 +264,7 @@ class StochasticDurationPredictor(nn.Module):
# posterior encoder - neg log likelihood # posterior encoder - neg log likelihood
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
nll_posterior_encoder = ( nll_posterior_encoder = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q
) )
z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
@ -279,7 +279,7 @@ class StochasticDurationPredictor(nn.Module):
z = torch.flip(z, [1]) z = torch.flip(z, [1])
# flow layers - neg log likelihood # flow layers - neg log likelihood
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll_flow_layers + nll_posterior_encoder return nll_flow_layers + nll_posterior_encoder
flows = list(reversed(self.flows)) flows = list(reversed(self.flows))

View File

@ -1,52 +1,14 @@
from typing import Dict, List, Union
from TTS.utils.generic_utils import find_module
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
print(" > Using model: {}".format(config.model))
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
MyModel = find_module("TTS.tts.models", config.base_model.lower())
else:
MyModel = find_module("TTS.tts.models", config.model.lower())
model = MyModel.init_from_config(config, samples)
if config.characters is not None:
# set characters from config
if hasattr(MyModel, "make_symbols"):
symbols = MyModel.make_symbols(config)
else:
symbols, phonemes = make_symbols(**config.characters)
else:
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
if config.use_phonemes:
symbols = phonemes
# use default characters and assign them to config
config.characters = parse_symbols()
# consider special `blank` character if `add_blank` is set True
num_chars = len(symbols) + getattr(config, "add_blank", False)
config.num_chars = num_chars
# compatibility fix
if "model_params" in config:
config.model_params.num_chars = num_chars
if "model_args" in config:
config.model_args.num_chars = num_chars
if config.model.lower() in ["vits"]: # If model supports multiple languages
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
else:
model = MyModel(config, speaker_manager=speaker_manager)
return model
# TODO: class registry
# def import_models(models_dir, namespace):
# for file in os.listdir(models_dir):
# path = os.path.join(models_dir, file)
# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
# model_name = file[: file.find(".py")] if file.endswith(".py") else file
# importlib.import_module(namespace + "." + model_name)
#
#
## automatically import any Python files in the models/ directory
# models_dir = os.path.dirname(__file__)
# import_models(models_dir, "TTS.tts.models")
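# Standalone sketch of the factory pattern above: `setup_model` only resolves the model class and
# delegates construction to its `init_from_config`, which assembles its own processors/managers.
# `FakeModel` and the registry dict are illustrative stand-ins, not the TTS API.
class FakeModel:
    def __init__(self, config, samples=None):
        self.config, self.samples = config, samples

    @staticmethod
    def init_from_config(config, samples=None):
        return FakeModel(config, samples)


def setup_model_sketch(model_registry, config, samples=None):
    MyModel = model_registry[config["model"].lower()]
    return MyModel.init_from_config(config, samples)


model = setup_model_sketch({"fakemodel": FakeModel}, {"model": "FakeModel"}, samples=[])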

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union
import torch
from coqpit import Coqpit
@ -12,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
@ -100,11 +102,16 @@ class AlignTTS(BaseTTS):
# pylint: disable=dangerous-default-value
def __init__(
self,
config: "AlignTTSConfig",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
self.speaker_manager = speaker_manager
self.phase = -1
self.length_scale = (
float(config.model_args.length_scale)
@ -112,10 +119,6 @@ class AlignTTS(BaseTTS):
else config.model_args.length_scale
)
if not self.config.model_args.num_chars:
_, self.config, num_chars = self.get_characters(config)
self.config.model_args.num_chars = num_chars
self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
self.embedded_speaker_dim = 0
@ -382,19 +385,17 @@ class AlignTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None:  # pylint: disable=no-self-use
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint(
self, config, checkpoint_path, eval=False
@ -430,3 +431,19 @@ class AlignTTS(BaseTTS):
def on_epoch_start(self, trainer):
"""Set AlignTTS training phase on epoch start."""
self.phase = self._set_phase(trainer.config, trainer.total_steps_done)
@staticmethod
def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (AlignTTSConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return AlignTTS(new_config, ap, tokenizer, speaker_manager)

View File

@ -9,6 +9,8 @@ from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
@ -17,8 +19,14 @@ from TTS.utils.training import gradual_training_scheduler
class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2"""
def __init__(
self,
config: "TacotronConfig",
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields as class attributes
for key in config:
@ -107,6 +115,16 @@ class BaseTacotron(BaseTTS):
"""Get the model criterion used in training."""
return TacotronLoss(self.config)
@staticmethod
def init_from_config(config: Coqpit):
"""Initialize model from config."""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config)
return BaseTacotron(config, ap, tokenizer, speaker_manager)
#############################
# COMMON COMPUTE FUNCTIONS
#############################

View File

@ -1,6 +1,6 @@
import os
import random
from typing import Dict, List, Tuple, Union
import torch
import torch.distributed as dist
@ -9,33 +9,44 @@ from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
class BaseTTS(BaseTrainerModel):
"""Base `tts` class. Every new `tts` model must inherit this.
It defines common `tts` specific functions on top of `Model` implementation.
Notes on input/output tensor shapes:
Any input or output tensor of the model must be shaped as
- 3D tensors `batch x time x channels`
- 2D tensors `batch x channels`
- 1D tensors `batch x 1`
"""
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
super().__init__()
self.config = config
self.ap = ap
self.tokenizer = tokenizer
self.speaker_manager = speaker_manager
self.language_manager = language_manager
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
`ModelArgs` has all the fields required to initialize the model architecture.
`ModelConfig` has all the fields required for training, inference and contains `ModelArgs`.
If the config is for training with a name like "*Config", then the model args are embedded in the
config.model_args
@ -44,8 +55,11 @@ class BaseTTS(BaseModel):
""" """
# don't use isintance not to import recursively # don't use isintance not to import recursively
if "Config" in config.__class__.__name__: if "Config" in config.__class__.__name__:
config_num_chars = (
self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars
)
num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
if "characters" in config: if "characters" in config:
_, self.config, num_chars = self.get_characters(config)
self.config.num_chars = num_chars self.config.num_chars = num_chars
if hasattr(self.config, "model_args"): if hasattr(self.config, "model_args"):
config.model_args.num_chars = num_chars config.model_args.num_chars = num_chars
@ -58,22 +72,6 @@ class BaseTTS(BaseModel):
else:
raise ValueError("config must be either a *Config or *Args")
@staticmethod
def get_characters(config: Coqpit) -> str:
# TODO: implement CharacterProcessor
if config.characters is not None:
symbols, phonemes = make_symbols(**config.characters)
else:
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
config.characters = CharactersConfig(**parse_symbols())
model_characters = phonemes if config.use_phonemes else symbols
num_chars = len(model_characters) + getattr(config, "add_blank", False)
return model_characters, config, num_chars
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needed and define expected embedding channel size for defining
`in_channels` size of the connected layers.
@ -170,8 +168,8 @@ class BaseTTS(BaseModel):
Dict: [description]
"""
# setup input batch
text_input = batch["token_id"]
text_lengths = batch["token_id_lengths"]
speaker_names = batch["speaker_names"]
linear_input = batch["linear"]
mel_input = batch["mel"]
@ -239,7 +237,7 @@ class BaseTTS(BaseModel):
config: Coqpit,
assets: Dict,
is_eval: bool,
samples: Union[List[Dict], List[List]],
verbose: bool,
num_gpus: int,
rank: int = None,
@ -247,8 +245,6 @@ class BaseTTS(BaseModel):
if is_eval and not config.run_eval:
loader = None
else:
ap = assets["audio_processor"]
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
@ -264,12 +260,8 @@ class BaseTTS(BaseModel):
speaker_id_mapping = None
d_vector_mapping = None
# setup multi-lingual attributes
if hasattr(self, "language_manager") and self.language_manager is not None:
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
@ -279,74 +271,40 @@ class BaseTTS(BaseModel):
# init dataloader # init dataloader
dataset = TTSDataset( dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1, outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
meta_data=data_items,
samples=samples,
ap=ap,
ap=self.ap,
characters=config.characters,
custom_symbols=custom_symbols,
add_blank=config["add_blank"],
return_wav=config.return_wav if "return_wav" in config else False,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
min_seq_len=config.min_seq_len,
min_text_len=config.min_text_len,
max_seq_len=config.max_seq_len,
max_text_len=config.max_text_len,
min_audio_len=config.min_audio_len,
max_audio_len=config.max_audio_len,
phoneme_cache_path=config.phoneme_cache_path,
use_phonemes=config.use_phonemes,
precompute_num_workers=config.precompute_num_workers,
phoneme_language=config.phoneme_language,
enable_eos_bos=config.enable_eos_bos_chars,
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest,
language_id_mapping=language_id_mapping,
)
# wait all the DDP process to be ready
# pre-compute phonemes
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
if hasattr(self, "eval_data_items") and is_eval:
dataset.items = self.eval_data_items
elif hasattr(self, "train_data_items") and not is_eval:
dataset.items = self.train_data_items
else:
# precompute phonemes for precise estimate of sequence lengths.
# otherwise `dataset.sort_items()` uses raw text lengths
dataset.compute_input_seq(config.num_loader_workers)
# TODO: find a more efficient solution
# cheap hack - store items in the model state to avoid recomputing when reinit the dataset
if is_eval:
self.eval_data_items = dataset.items
else:
self.train_data_items = dataset.items
# halt DDP processes for the main process to finish computing the phoneme cache
if num_gpus > 1:
dist.barrier()
# sort input sequences from short to long
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
dataset.preprocess_samples()
# compute pitch frames and write to files.
if config.compute_f0 and rank in [None, 0]:
if not os.path.exists(config.f0_cache_path):
dataset.pitch_extractor.compute_pitch(
ap, config.get("f0_cache_path", None), config.num_loader_workers
)
# halt DDP processes for the main process to finish computing the F0 cache
if num_gpus > 1:
dist.barrier()
# load pitch stats computed above by all the workers
if config.compute_f0:
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# Weighted samplers
# TODO: make this DDP amenable
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
@@ -357,17 +315,17 @@ class BaseTTS(BaseModel):
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.items)
sampler = get_language_weighted_sampler(dataset.samples)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.items)
sampler = get_speaker_weighted_sampler(dataset.samples)
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False,
shuffle=False,  # shuffle is done in the dataset.
collate_fn=dataset.collate_fn,
drop_last=False,
drop_last=False,  # setting this False might cause issues in AMP training.
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
@@ -403,7 +361,6 @@ class BaseTTS(BaseModel):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
ap = assets["audio_processor"]
print(" | > Synthesizing test sentences.") print(" | > Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
@@ -415,17 +372,15 @@ class BaseTTS(BaseModel):
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"], speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"], d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"], style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"] test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram( test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], ap, output_fig=False outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
) )
test_figures["{}-alignment".format(idx)] = plot_alignment( test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False outputs_dict["outputs"]["alignments"], output_fig=False
View File
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, Tuple
from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
@@ -14,7 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
@dataclass
@@ -170,17 +171,22 @@ class ForwardTTS(BaseTTS):
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
self._set_model_args(config)
super().__init__(config)
self.speaker_manager = speaker_manager
self.init_multispeaker(config)
self.max_duration = self.args.max_duration
self.use_aligner = self.args.use_aligner
self.use_pitch = self.args.use_pitch
self.use_binary_alignment_loss = False
self.binary_loss_weight = 0.0
self.length_scale = (
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
@@ -255,7 +261,7 @@ class ForwardTTS(BaseTTS):
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels)
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@staticmethod
@@ -638,8 +644,9 @@ class ForwardTTS(BaseTTS):
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
alignment_soft=outputs["alignment_soft"],
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
alignment_hard=outputs["alignment_mas"],
binary_loss_weight=self.binary_loss_weight,
)
# compute duration error
durations_pred = outputs["durations"]
@@ -666,17 +673,12 @@ class ForwardTTS(BaseTTS):
# plot pitch figures
if self.args.use_pitch:
pitch = batch["pitch"]
pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy())
pitch_avg_expanded, _ = self.expand_encoder_outputs(
pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy())
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
)
pitch = pitch[0, 0].data.cpu().numpy()
# TODO: denormalize before plotting
pitch = abs(pitch)
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
pitch_figures = {
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
"pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
"pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
}
figures.update(pitch_figures)
@@ -692,19 +694,17 @@ class ForwardTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None:  # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint(
self, config, checkpoint_path, eval=False
@@ -721,6 +721,21 @@ class ForwardTTS(BaseTTS):
return ForwardTTSLoss(self.config)
def on_train_step_start(self, trainer):
"""Enable binary alignment loss when needed"""
"""Schedule binary loss weight."""
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
self.use_binary_alignment_loss = True
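
The new `on_train_step_start` replaces the on/off switch for the binary alignment loss with a linear warmup of `binary_loss_weight` over `binary_loss_warmup_epochs`, capped at 1.0. A small sketch of the ramp, assuming a 10-epoch warmup (the actual config value is not part of this diff):

```python
# Illustration only; `warmup_epochs` is an assumed example value.
warmup_epochs = 10
for epochs_done in (0, 2, 5, 10, 25):
    binary_loss_weight = min(epochs_done / warmup_epochs, 1.0) * 1.0
    print(epochs_done, binary_loss_weight)  # -> 0.0, 0.2, 0.5, 1.0, 1.0 (capped)
```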
@staticmethod
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (ForwardTTSConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return ForwardTTS(new_config, ap, tokenizer, speaker_manager)
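
For reference, a hypothetical usage of the new factory method; `config` stands for any ForwardTTS-compatible `*Config` and `train_samples` for the sample list produced by the data loading step (neither is defined in this diff):

```python
# Hypothetical usage sketch; the names below are placeholders, not part of this commit.
model = ForwardTTS.init_from_config(config, samples=train_samples)
# The model now owns its audio processor and tokenizer instead of reading them from `assets`:
print(model.ap.sample_rate)
print(model.tokenizer)
```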
View File
@@ -1,5 +1,5 @@
import math
from typing import Dict, Tuple, Union
from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
@@ -14,6 +14,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
@@ -39,18 +40,31 @@ class GlowTTS(BaseTTS):
Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
Examples:
Init only model layers.
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig(num_chars=2)
>>> model = GlowTTS(config)
Fully init a model ready for action. All the class attributes and class members
(e.g. Tokenizer, AudioProcessor, etc.) are initialized internally based on config values.
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig()
>>> model = GlowTTS(config)
>>> model = GlowTTS.init_from_config(config, verbose=False)
"""
def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: GlowTTSConfig,
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config)
super().__init__(config, ap, tokenizer, speaker_manager)
self.speaker_manager = speaker_manager
# pass all config fields to `self`
# for fewer code change
@@ -58,7 +72,6 @@ class GlowTTS(BaseTTS):
for key in config:
setattr(self, key, config[key])
_, self.config, self.num_chars = self.get_characters(config)
self.decoder_output_dim = config.out_channels
# init multi-speaker layers if necessary
@@ -94,25 +107,25 @@ class GlowTTS(BaseTTS):
def init_multispeaker(self, config: Coqpit):
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets
speaker embedding vector dimension to the d-vector dimension from the config.
Args:
config (Coqpit): Model configuration.
"""
self.embedded_speaker_dim = 0
# init speaker manager
if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
raise ValueError(
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
)
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers
# set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file:
if config.use_d_vector_file:
self.embedded_speaker_dim = (
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
)
if self.speaker_manager is not None:
assert (
config.d_vector_dim == self.speaker_manager.d_vector_dim
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
@@ -170,6 +183,8 @@ class GlowTTS(BaseTTS):
if g is not None:
if hasattr(self, "emb_g"):
# use speaker embedding layer
if not g.size(): # if is a scalar
g = g.unsqueeze(0) # unsqueeze
g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
else:
# use d-vector
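
A side note on the scalar check added above: `nn.Embedding` expects at least a 1-D index tensor, so a bare speaker-id scalar has to be unsqueezed first. A tiny illustration (values are arbitrary):

```python
# Illustration only; the speaker id value is arbitrary.
import torch

g = torch.tensor(3)      # 0-dim tensor: g.size() == torch.Size([])
if not g.size():         # same check as in the diff: an empty size means a scalar
    g = g.unsqueeze(0)   # -> shape [1], now a valid input for nn.Embedding
```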
@@ -180,12 +195,33 @@ class GlowTTS(BaseTTS):
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
"""
Shapes:
- x: :math:`[B, T]`
- x_lenghts::math:`B`
- y: :math:`[B, T, C]`
- y_lengths::math:`B`
- g: :math:`[B, C] or B`
Args:
x (torch.Tensor):
Input text sequence ids. :math:`[B, T_en]`
x_lengths (torch.Tensor):
Lengths of input text sequences. :math:`[B]`
y (torch.Tensor):
Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
y_lengths (torch.Tensor):
Lengths of target mel-spectrogram frames. :math:`[B]`
aux_input (Dict):
Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model.
:math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model using a speaker-embedding
layer. :math:`B`
Returns:
Dict:
- z: :math: `[B, T_de, C]`
- logdet: :math:`B`
- y_mean: :math:`[B, T_de, C]`
- y_log_scale: :math:`[B, T_de, C]`
- alignments: :math:`[B, T_en, T_de]`
- durations_log: :math:`[B, T_en, 1]`
- total_durations_log: :math:`[B, T_en, 1]`
""" """
# [B, T, C] -> [B, C, T] # [B, T, C] -> [B, C, T]
y = y.transpose(1, 2) y = y.transpose(1, 2)
@@ -206,9 +242,9 @@ class GlowTTS(BaseTTS):
with torch.no_grad():
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
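
For readers tracing the unchanged `logp1`…`logp4` terms above: with `o_scale = exp(-2 * o_log_scale) = 1/σ²`, they are the expansion of a diagonal-Gaussian log-density of `z`, summed over channels, and the matrix products evaluate it for every (text, frame) pair at once before `maximum_path` picks the best alignment. A sketch of the identity being implemented:

```latex
\log \mathcal{N}(z;\,\mu,\sigma^{2})
  = \underbrace{-\tfrac{1}{2}\log(2\pi) - \log\sigma}_{\mathrm{logp1}}
    \;\underbrace{-\;\tfrac{z^{2}}{2\sigma^{2}}}_{\mathrm{logp2}}
    \;+\;\underbrace{\tfrac{\mu z}{\sigma^{2}}}_{\mathrm{logp3}}
    \;\underbrace{-\;\tfrac{\mu^{2}}{2\sigma^{2}}}_{\mathrm{logp4}}
```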
@@ -255,9 +291,9 @@ class GlowTTS(BaseTTS):
# find the alignment path between z and encoder output
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
@@ -422,20 +458,18 @@ class GlowTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None:  # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
@torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@torch.no_grad()
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
@@ -446,7 +480,6 @@ class GlowTTS(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
ap = assets["audio_processor"]
print(" | > Synthesizing test sentences.") print(" | > Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
@@ -461,18 +494,16 @@ class GlowTTS(BaseTTS):
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs["wav"] test_audios["{}-audio".format(idx)] = outputs["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram( test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs["outputs"]["model_outputs"], ap, output_fig=False outputs["outputs"]["model_outputs"], self.ap, output_fig=False
) )
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
return test_figures, test_audios return test_figures, test_audios
@@ -499,7 +530,8 @@ class GlowTTS(BaseTTS):
self.store_inverse()
assert not self.training
def get_criterion(self):
@staticmethod
def get_criterion():
from TTS.tts.layers.losses import GlowTTSLoss  # pylint: disable=import-outside-toplevel
return GlowTTSLoss()
@@ -507,3 +539,20 @@ class GlowTTS(BaseTTS):
def on_train_step_start(self, trainer):
"""Decide on every training step whether to enable/disable data-dependent initialization."""
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
@staticmethod
def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
"""Initiate model from config
Args:
config (GlowTTSConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return GlowTTS(new_config, ap, tokenizer, speaker_manager)
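
Putting the docstring's two init paths next to each other, a short usage sketch (it mirrors the examples shown above; no API beyond this diff is assumed):

```python
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS

# Layers-only init: num_chars must be given explicitly; no AudioProcessor/Tokenizer is attached.
model = GlowTTS(GlowTTSConfig(num_chars=2))

# Full init: AudioProcessor, TTSTokenizer and (optionally) a SpeakerManager are built from the config.
model = GlowTTS.init_from_config(GlowTTSConfig(), verbose=False)
```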
View File
@@ -1,7 +1,8 @@
# coding: utf-8
from typing import Dict, List, Union
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
@@ -10,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -24,12 +26,15 @@ class Tacotron(BaseTacotron):
a multi-speaker model. Defaults to None.
"""
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
def __init__(
self,
config: "TacotronConfig",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
self.speaker_manager = speaker_manager
super().__init__(config, ap, tokenizer, speaker_manager)
chars, self.config, _ = self.get_characters(config)
config.num_chars = self.num_chars = len(chars)
# pass all config fields to `self`
# for fewer code change
@@ -302,16 +307,30 @@ class Tacotron(BaseTacotron):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None:  # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (TacotronConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return Tacotron(new_config, ap, tokenizer, speaker_manager)
View File
@@ -1,9 +1,8 @@
# coding: utf-8
from typing import Dict
from typing import Dict, List, Union
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
@@ -12,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -40,12 +40,16 @@ class Tacotron2(BaseTacotron):
Speaker manager for multi-speaker training. Use only for multi-speaker training. Defaults to None.
"""
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
def __init__(
self,
config: "Tacotron2Config",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
self.speaker_manager = speaker_manager
chars, self.config, _ = self.get_characters(config)
config.num_chars = len(chars)
self.decoder_output_dim = config.out_channels
# pass all config fields to `self`
@@ -325,16 +329,30 @@ class Tacotron2(BaseTacotron):
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None:  # pylint: disable=no-self-use
"""Log training progress."""
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (Tacotron2Config): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(new_config, samples)
return Tacotron2(new_config, ap, tokenizer, speaker_manager)
File diff suppressed because it is too large
View File
@@ -1,20 +0,0 @@
## Utilities to Convert Models to Tensorflow2
Here are experimental utilities to convert trained Torch models to Tensorflow (>=2.2).
Converting Torch models to TF enables the full TF toolkit to be used for better deployment and device-specific optimizations.
Note that we do not plan to share training scripts for Tensorflow in the near future. But any contribution in that direction would be more than welcome.
To see how you can use a TF model at inference, check the notebook.
This is an experimental release. If you encounter an error, please open an issue or, even better, send a PR, but you are mostly on your own.
### Converting a Model
- Run ```convert_tacotron2_torch_to_tf.py --torch_model_path /path/to/torch/model.pth.tar --config_path /path/to/model/config.json --output_path /path/to/output/tf/model``` with the right arguments.
### Known issues and limitations
- We use a custom model load/save mechanism which enables us to store model-related information with the model weights (similar to Torch). However, it is prone to random errors.
- The current TF model implementation is slightly slower than the Torch model. Hopefully, it'll get better with improving TF support for eager mode and ```tf.function```.
- The TF implementation of Tacotron2 only supports regular Tacotron2 as in the paper.
- You can only convert models trained after the TF model implementation was added, since the model layers have been updated in the Torch model.
View File
@@ -1,301 +0,0 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops
# from tensorflow_addons.seq2seq import BahdanauAttention
# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class Linear(keras.layers.Layer):
def __init__(self, units, use_bias, **kwargs):
super().__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
self.activation = keras.layers.ReLU()
def call(self, x):
"""
shapes:
x: B x T x C
"""
return self.activation(self.linear_layer(x))
class LinearBN(keras.layers.Layer):
def __init__(self, units, use_bias, **kwargs):
super().__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
self.batch_normalization = keras.layers.BatchNormalization(
axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization"
)
self.activation = keras.layers.ReLU()
def call(self, x, training=None):
"""
shapes:
x: B x T x C
"""
out = self.linear_layer(x)
out = self.batch_normalization(out, training=training)
return self.activation(out)
class Prenet(keras.layers.Layer):
def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs):
super().__init__(**kwargs)
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
self.linear_layers = []
if prenet_type == "bn":
self.linear_layers += [
LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
]
elif prenet_type == "original":
self.linear_layers += [
Linear(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
]
else:
raise RuntimeError(" [!] Unknown prenet type.")
if prenet_dropout:
self.dropout = keras.layers.Dropout(rate=0.5)
def call(self, x, training=None):
"""
shapes:
x: B x T x C
"""
for linear in self.linear_layers:
if self.prenet_dropout:
x = self.dropout(linear(x), training=training)
else:
x = linear(x)
return x
def _sigmoid_norm(score):
attn_weights = tf.nn.sigmoid(score)
attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
return attn_weights
class Attention(keras.layers.Layer):
"""TODO: implement forward_attention
TODO: location sensitive attention
TODO: implement attention windowing"""
def __init__(
self,
attn_dim,
use_loc_attn,
loc_attn_n_filters,
loc_attn_kernel_size,
use_windowing,
norm,
use_forward_attn,
use_trans_agent,
use_forward_attn_mask,
**kwargs,
):
super().__init__(**kwargs)
self.use_loc_attn = use_loc_attn
self.loc_attn_n_filters = loc_attn_n_filters
self.loc_attn_kernel_size = loc_attn_kernel_size
self.use_windowing = use_windowing
self.norm = norm
self.use_forward_attn = use_forward_attn
self.use_trans_agent = use_trans_agent
self.use_forward_attn_mask = use_forward_attn_mask
self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer")
self.inputs_layer = tf.keras.layers.Dense(
attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer"
)
self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer")
if use_loc_attn:
self.location_conv1d = keras.layers.Conv1D(
filters=loc_attn_n_filters,
kernel_size=loc_attn_kernel_size,
padding="same",
use_bias=False,
name="location_layer/location_conv1d",
)
self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense")
if norm == "softmax":
self.norm_func = tf.nn.softmax
elif norm == "sigmoid":
self.norm_func = _sigmoid_norm
else:
raise ValueError("Unknown value for attention norm type")
def init_states(self, batch_size, value_length):
states = []
if self.use_loc_attn:
attention_cum = tf.zeros([batch_size, value_length])
attention_old = tf.zeros([batch_size, value_length])
states = [attention_cum, attention_old]
if self.use_forward_attn:
alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1)
states.append(alpha)
return tuple(states)
def process_values(self, values):
"""cache values for decoder iterations"""
# pylint: disable=attribute-defined-outside-init
self.processed_values = self.inputs_layer(values)
self.values = values
def get_loc_attn(self, query, states):
"""compute location attention, query layer and
unnorm. attention weights"""
attention_cum, attention_old = states[:2]
attn_cat = tf.stack([attention_old, attention_cum], axis=2)
processed_query = self.query_layer(tf.expand_dims(query, 1))
processed_attn = self.location_dense(self.location_conv1d(attn_cat))
score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn))
score = tf.squeeze(score, axis=2)
return score, processed_query
def get_attn(self, query):
"""compute query layer and unnormalized attention weights"""
processed_query = self.query_layer(tf.expand_dims(query, 1))
score = self.v(tf.nn.tanh(self.processed_values + processed_query))
score = tf.squeeze(score, axis=2)
return score, processed_query
def apply_score_masking(self, score, mask): # pylint: disable=no-self-use
"""ignore sequence paddings"""
padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
# Bias so padding positions do not contribute to attention distribution.
score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32)
return score
def apply_forward_attention(self, alignment, alpha): # pylint: disable=no-self-use
# forward attention
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
# compute transition potentials
new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
# renormalize attention weights
new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
return new_alpha
def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
states = []
if self.use_loc_attn:
states = [old_states[0] + scores_norm, attn_weights]
if self.use_forward_attn:
states.append(new_alpha)
return tuple(states)
def call(self, query, states):
"""
shapes:
query: B x D
"""
if self.use_loc_attn:
score, _ = self.get_loc_attn(query, states)
else:
score, _ = self.get_attn(query)
# TODO: masking
# if mask is not None:
# self.apply_score_masking(score, mask)
# attn_weights shape == (batch_size, max_length, 1)
# normalize attention scores
scores_norm = self.norm_func(score)
attn_weights = scores_norm
# apply forward attention
new_alpha = None
if self.use_forward_attn:
new_alpha = self.apply_forward_attention(attn_weights, states[-1])
attn_weights = new_alpha
# update states tuple
# states = (cum_attn_weights, attn_weights, new_alpha)
states = self.update_states(states, scores_norm, attn_weights, new_alpha)
# context_vector shape after sum == (batch_size, hidden_size)
context_vector = tf.matmul(
tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False
)
context_vector = tf.squeeze(context_vector, axis=1)
return context_vector, attn_weights, states
# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
# dtype = processed_query.dtype
# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])
# class LocationSensitiveAttention(BahdanauAttention):
# def __init__(self,
# units,
# memory=None,
# memory_sequence_length=None,
# normalize=False,
# probability_fn="softmax",
# kernel_initializer="glorot_uniform",
# dtype=None,
# name="LocationSensitiveAttention",
# location_attention_filters=32,
# location_attention_kernel_size=31):
# super( self).__init__(units=units,
# memory=memory,
# memory_sequence_length=memory_sequence_length,
# normalize=normalize,
# probability_fn='softmax', ## parent module default
# kernel_initializer=kernel_initializer,
# dtype=dtype,
# name=name)
# if probability_fn == 'sigmoid':
# self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
# self.location_dense = keras.layers.Dense(units, use_bias=False)
# # self.v = keras.layers.Dense(1, use_bias=True)
# def _location_sensitive_score(self, processed_query, keys, processed_loc):
# processed_query = tf.expand_dims(processed_query, 1)
# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])
# def _location_sensitive(self, alignment_cum, alignment_old):
# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
# return self.location_dense(self.location_conv(alignment_cat))
# def _sigmoid_normalization(self, score):
# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)
# # def _apply_masking(self, score, mask):
# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
# # # Bias so padding positions do not contribute to attention distribution.
# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
# # return score
# def _calculate_attention(self, query, state):
# alignment_cum, alignment_old = state[:2]
# processed_query = self.query_layer(
# query) if self.query_layer else query
# processed_loc = self._location_sensitive(alignment_cum, alignment_old)
# score = self._location_sensitive_score(
# processed_query,
# self.keys,
# processed_loc)
# alignment = self.probability_fn(score, state)
# alignment_cum = alignment_cum + alignment
# state[0] = alignment_cum
# state[1] = alignment
# return alignment, state
# def compute_context(self, alignments):
# expanded_alignments = tf.expand_dims(alignments, 1)
# context = tf.matmul(expanded_alignments, self.values)
# context = tf.squeeze(context, [1])
# return context
# # def call(self, query, state):
# # alignment, next_state = self._calculate_attention(query, state)
# # return alignment, next_state
View File
@@ -1,322 +0,0 @@
import tensorflow as tf
from tensorflow import keras
from TTS.tts.tf.layers.tacotron.common_layers import Attention, Prenet
from TTS.tts.tf.utils.tf_utils import shape_list
# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class ConvBNBlock(keras.layers.Layer):
def __init__(self, filters, kernel_size, activation, **kwargs):
super().__init__(**kwargs)
self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d")
self.batch_normalization = keras.layers.BatchNormalization(
axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization"
)
self.dropout = keras.layers.Dropout(rate=0.5, name="dropout")
self.activation = keras.layers.Activation(activation, name="activation")
def call(self, x, training=None):
o = self.convolution1d(x)
o = self.batch_normalization(o, training=training)
o = self.activation(o)
o = self.dropout(o, training=training)
return o
class Postnet(keras.layers.Layer):
def __init__(self, output_filters, num_convs, **kwargs):
super().__init__(**kwargs)
self.convolutions = []
self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0"))
for idx in range(1, num_convs - 1):
self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}"))
self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}"))
def call(self, x, training=None):
o = x
for layer in self.convolutions:
o = layer(o, training=training)
return o
class Encoder(keras.layers.Layer):
def __init__(self, output_input_dim, **kwargs):
super().__init__(**kwargs)
self.convolutions = []
for idx in range(3):
self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}"))
self.lstm = keras.layers.Bidirectional(
keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm"
)
def call(self, x, training=None):
o = x
for layer in self.convolutions:
o = layer(o, training=training)
o = self.lstm(o)
return o
class Decoder(keras.layers.Layer):
# pylint: disable=unused-argument
def __init__(
self,
frame_dim,
r,
attn_type,
use_attn_win,
attn_norm,
prenet_type,
prenet_dropout,
use_forward_attn,
use_trans_agent,
use_forward_attn_mask,
use_location_attn,
attn_K,
separate_stopnet,
speaker_emb_dim,
enable_tflite,
**kwargs,
):
super().__init__(**kwargs)
self.frame_dim = frame_dim
self.r_init = tf.constant(r, dtype=tf.int32)
self.r = tf.constant(r, dtype=tf.int32)
self.output_dim = r * self.frame_dim
self.separate_stopnet = separate_stopnet
self.enable_tflite = enable_tflite
# layer constants
self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
# model dimensions
self.query_dim = 1024
self.decoder_rnn_dim = 1024
self.prenet_dim = 256
self.attn_dim = 128
self.p_attention_dropout = 0.1
self.p_decoder_dropout = 0.1
self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet")
self.attention_rnn = keras.layers.LSTMCell(
self.query_dim,
use_bias=True,
name="attention_rnn",
)
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
# TODO: implement other attn options
self.attention = Attention(
attn_dim=self.attn_dim,
use_loc_attn=True,
loc_attn_n_filters=32,
loc_attn_kernel_size=31,
use_windowing=False,
norm=attn_norm,
use_forward_attn=use_forward_attn,
use_trans_agent=use_trans_agent,
use_forward_attn_mask=use_forward_attn_mask,
name="attention",
)
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn")
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer")
self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer")
def set_max_decoder_steps(self, new_max_steps):
self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)
def set_r(self, new_r):
self.r = tf.constant(new_r, dtype=tf.int32)
self.output_dim = self.frame_dim * new_r
def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
zero_frame = tf.zeros([batch_size, self.frame_dim])
zero_context = tf.zeros([batch_size, memory_dim])
attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
attention_states = self.attention.init_states(batch_size, memory_length)
return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states
def step(self, prenet_next, states, memory_seq_length=None, training=None):
_, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
attention_rnn_input = tf.concat([prenet_next, context_next], -1)
attention_rnn_output, attention_rnn_state = self.attention_rnn(
attention_rnn_input, attention_rnn_state, training=training
)
attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
decoder_rnn_output, decoder_rnn_state = self.decoder_rnn(
decoder_rnn_input, decoder_rnn_state, training=training
)
decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
output_frame = self.linear_projection(linear_projection_input, training=training)
stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
stopnet_output = self.stopnet(stopnet_input, training=training)
output_frame = output_frame[:, : self.r * self.frame_dim]
states = (
output_frame[:, self.frame_dim * (self.r - 1) :],
context,
attention_rnn_state,
decoder_rnn_state,
attention_states,
)
return output_frame, stopnet_output, states, attention
def decode(self, memory, states, frames, memory_seq_length=None):
B, _, _ = shape_list(memory)
num_iter = shape_list(frames)[1] // self.r
# init states
frame_zero = tf.expand_dims(states[0], 1)
frames = tf.concat([frame_zero, frames], axis=1)
outputs = tf.TensorArray(dtype=tf.float32, size=num_iter)
attentions = tf.TensorArray(dtype=tf.float32, size=num_iter)
stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter)
# pre-computes
self.attention.process_values(memory)
prenet_output = self.prenet(frames, training=True)
step_count = tf.constant(0, dtype=tf.int32)
def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
prenet_next = prenet_output[:, step]
output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length)
outputs = outputs.write(step, output)
attentions = attentions.write(step, attention)
stop_tokens = stop_tokens.write(step, stop_token)
return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
_, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop(
lambda *arg: True,
_body,
loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=num_iter,
)
outputs = outputs.stack()
attentions = attentions.stack()
stop_tokens = stop_tokens.stack()
outputs = tf.transpose(outputs, [1, 0, 2])
attentions = tf.transpose(attentions, [1, 0, 2])
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
stop_tokens = tf.squeeze(stop_tokens, axis=2)
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def decode_inference(self, memory, states):
B, _, _ = shape_list(memory)
# init states
outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
# pre-computes
self.attention.process_values(memory)
# iter vars
stop_flag = tf.constant(False, dtype=tf.bool)
step_count = tf.constant(0, dtype=tf.int32)
def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
frame_next = states[0]
prenet_next = self.prenet(frame_next, training=False)
output, stop_token, states, attention = self.step(prenet_next, states, None, training=False)
stop_token = tf.math.sigmoid(stop_token)
outputs = outputs.write(step, output)
attentions = attentions.write(step, attention)
stop_tokens = stop_tokens.write(step, stop_token)
stop_flag = tf.greater(stop_token, self.stop_thresh)
stop_flag = tf.reduce_all(stop_flag)
return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag
cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
_, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop(
cond,
_body,
loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps,
)
outputs = outputs.stack()
attentions = attentions.stack()
stop_tokens = stop_tokens.stack()
outputs = tf.transpose(outputs, [1, 0, 2])
attentions = tf.transpose(attentions, [1, 0, 2])
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
stop_tokens = tf.squeeze(stop_tokens, axis=2)
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def decode_inference_tflite(self, memory, states):
"""Inference with TF-Lite compatibility. It assumes
batch_size is 1"""
# init states
# dynamic_shape is not supported in TFLite
outputs = tf.TensorArray(
dtype=tf.float32,
size=self.max_decoder_steps,
element_shape=tf.TensorShape([self.output_dim]),
clear_after_read=False,
dynamic_size=False,
)
# stop_flags = tf.TensorArray(dtype=tf.bool,
# size=self.max_decoder_steps,
# element_shape=tf.TensorShape(
# []),
# clear_after_read=False,
# dynamic_size=False)
attentions = ()
stop_tokens = ()
# pre-computes
self.attention.process_values(memory)
# iter vars
stop_flag = tf.constant(False, dtype=tf.bool)
step_count = tf.constant(0, dtype=tf.int32)
def _body(step, memory, states, outputs, stop_flag):
frame_next = states[0]
prenet_next = self.prenet(frame_next, training=False)
output, stop_token, states, _ = self.step(prenet_next, states, None, training=False)
stop_token = tf.math.sigmoid(stop_token)
stop_flag = tf.greater(stop_token, self.stop_thresh)
stop_flag = tf.reduce_all(stop_flag)
# stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
outputs = outputs.write(step, tf.reshape(output, [-1]))
return step + 1, memory, states, outputs, stop_flag
cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
step_count, memory, states, outputs, stop_flag = tf.while_loop(
cond,
_body,
loop_vars=(step_count, memory, states, outputs, stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps,
)
outputs = outputs.stack()
outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter
outputs = tf.expand_dims(outputs, axis=[0])
outputs = tf.transpose(outputs, [1, 0, 2])
outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
if training:
return self.decode(memory, states, frames, memory_seq_length)
if self.enable_tflite:
return self.decode_inference_tflite(memory, states)
return self.decode_inference(memory, states)

View File

@ -1,116 +0,0 @@
import tensorflow as tf
from tensorflow import keras
from TTS.tts.tf.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.tf.utils.tf_utils import shape_list
# pylint: disable=too-many-ancestors, abstract-method
class Tacotron2(keras.models.Model):
def __init__(
self,
num_chars,
num_speakers,
r,
out_channels=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
attn_K=4,
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True,
bidirectional_decoder=False,
enable_tflite=False,
):
super().__init__()
self.r = r
self.decoder_output_dim = decoder_output_dim
self.out_channels = out_channels
self.bidirectional_decoder = bidirectional_decoder
self.num_speakers = num_speakers
self.speaker_embed_dim = 256
self.enable_tflite = enable_tflite
self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding")
self.encoder = Encoder(512, name="encoder")
# TODO: most of the decoder args have no use at the moment
self.decoder = Decoder(
decoder_output_dim,
r,
attn_type=attn_type,
use_attn_win=attn_win,
attn_norm=attn_norm,
prenet_type=prenet_type,
prenet_dropout=prenet_dropout,
use_forward_attn=forward_attn,
use_trans_agent=trans_agent,
use_forward_attn_mask=forward_attn_mask,
use_location_attn=location_attn,
attn_K=attn_K,
separate_stopnet=separate_stopnet,
speaker_emb_dim=self.speaker_embed_dim,
name="decoder",
enable_tflite=enable_tflite,
)
self.postnet = Postnet(out_channels, 5, name="postnet")
@tf.function(experimental_relax_shapes=True)
def call(self, characters, text_lengths=None, frames=None, training=None):
if training:
return self.training(characters, text_lengths, frames)
if not training:
return self.inference(characters)
raise RuntimeError(" [!] Set model training mode True or False")
def training(self, characters, text_lengths, frames):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=True)
encoder_output = self.encoder(embedding_vectors, training=True)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(
encoder_output, decoder_states, frames, text_lengths, training=True
)
postnet_frames = self.postnet(decoder_frames, training=True)
output_frames = decoder_frames + postnet_frames
return decoder_frames, output_frames, attentions, stop_tokens
def inference(self, characters):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=False)
encoder_output = self.encoder(embedding_vectors, training=False)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
postnet_frames = self.postnet(decoder_frames, training=False)
output_frames = decoder_frames + postnet_frames
print(output_frames.shape)
return decoder_frames, output_frames, attentions, stop_tokens
@tf.function(
experimental_relax_shapes=True,
input_signature=[
tf.TensorSpec([1, None], dtype=tf.int32),
],
)
def inference_tflite(self, characters):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=False)
encoder_output = self.encoder(embedding_vectors, training=False)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
postnet_frames = self.postnet(decoder_frames, training=False)
output_frames = decoder_frames + postnet_frames
print(output_frames.shape)
return decoder_frames, output_frames, attentions, stop_tokens
def build_inference(
self,
):
# TODO: issue https://github.com/PyCQA/pylint/issues/3613
input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) # pylint: disable=unexpected-keyword-arg
self(input_ids)

View File

@ -1,87 +0,0 @@
import numpy as np
import tensorflow as tf
# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
def tf_create_dummy_inputs():
"""Create dummy inputs for TF Tacotron2 model"""
batch_size = 4
max_input_length = 32
max_mel_length = 128
pad = 1
n_chars = 24
input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32)
input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size])
input_lengths[-1] = max_input_length
input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32)
mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80])
mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size])
mel_lengths[-1] = max_mel_length
mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)
return input_ids, input_lengths, mel_outputs, mel_lengths
def compare_torch_tf(torch_tensor, tf_tensor):
"""Compute the average absolute difference b/w torch and tf tensors"""
return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean()
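A quick sanity check of this helper (hypothetical values): feed the same NumPy array to both frameworks and expect a near-zero difference.

import numpy as np
import tensorflow as tf
import torch

a = np.random.rand(2, 3).astype("float32")
print(compare_torch_tf(torch.from_numpy(a), tf.convert_to_tensor(a)))  # ~0.0 for identical values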
def convert_tf_name(tf_name):
"""Convert certain patterns in TF layer names to Torch patterns"""
tf_name_tmp = tf_name
tf_name_tmp = tf_name_tmp.replace(":0", "")
tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0")
tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1")
tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh")
tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight")
tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight")
tf_name_tmp = tf_name_tmp.replace("/beta", "/bias")
tf_name_tmp = tf_name_tmp.replace("/", ".")
return tf_name_tmp
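For example, a hypothetical TF variable name maps to its Torch-style key like this:

print(convert_tf_name("decoder/linear_layer/kernel:0"))  # -> "decoder.linear_layer.weight"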
def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
"""Transfer weights from torch state_dict to TF variables"""
print(" > Passing weights from Torch to TF ...")
for tf_var in tf_vars:
torch_var_name = var_map_dict[tf_var.name]
print(f" | > {tf_var.name} <-- {torch_var_name}")
# if tuple, it is a bias variable
if not isinstance(torch_var_name, tuple):
torch_layer_name = ".".join(torch_var_name.split(".")[-2:])
torch_weight = state_dict[torch_var_name]
if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name:
# out_dim, in_dim, filter -> filter, in_dim, out_dim
numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy()
elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
# if variable is for bidirectional lstm and it is a bias vector there
# needs to be pre-defined two matching torch bias vectors
elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name:
bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name]
assert len(bias_vectors) == 2
numpy_weight = bias_vectors[0] + bias_vectors[1]
elif "rnn" in tf_var.name and "kernel" in tf_var.name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
elif "rnn" in tf_var.name and "bias" in tf_var.name:
bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key]
assert len(bias_vectors) == 2
numpy_weight = bias_vectors[0] + bias_vectors[1]
elif "linear_layer" in torch_layer_name and "weight" in torch_var_name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
else:
numpy_weight = torch_weight.detach().cpu().numpy()
assert np.all(
tf_var.shape == numpy_weight.shape
), f" [!] weight shapes do not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
tf.keras.backend.set_value(tf_var, numpy_weight)
return tf_vars
def load_tf_vars(model_tf, tf_vars):
for tf_var in tf_vars:
model_tf.get_layer(tf_var.name).set_weights(tf_var)
return model_tf

View File

@ -1,105 +0,0 @@
import datetime
import importlib
import pickle
import fsspec
import numpy as np
import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
state = {
"model": model.weights,
"optimizer": optimizer,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
"r": r,
}
state.update(kwargs)
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
layer_name = tf_var.name
try:
chkp_var_value = chkp_var_dict[layer_name]
except KeyError:
class_name = list(chkp_var_dict.keys())[0].split("/")[0]
layer_name = f"{class_name}/{layer_name}"
chkp_var_value = chkp_var_dict[layer_name]
tf.keras.backend.set_value(tf_var, chkp_var_value)
if "r" in checkpoint.keys():
model.decoder.set_r(checkpoint["r"])
return model
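A minimal save/load round-trip with these pickle-based helpers; `model` and `optimizer` are placeholders for an already built TF model and its optimizer, and the path is hypothetical.

save_checkpoint(model, optimizer, current_step=1000, epoch=5, r=2, output_path="checkpoint_1000.pkl")
model = load_checkpoint(model, "checkpoint_1000.pkl")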
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.max()
batch_size = sequence_length.size(0)
seq_range = np.empty([0, max_len], dtype=np.int8)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_range_expand = seq_range_expand.type_as(sequence_length)
seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
# B x T_max
return seq_range_expand < seq_length_expand
# @tf.custom_gradient
def check_gradient(x, grad_clip):
x_normed = tf.clip_by_norm(x, grad_clip)
grad_norm = tf.norm(grad_clip)
return x_normed, grad_norm
def count_parameters(model, c):
try:
return model.count_params()
except RuntimeError:
input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32"))
input_lengths = np.random.randint(100, 129, (8,))
input_lengths[-1] = 128
input_lengths = tf.convert_to_tensor(input_lengths.astype("int32"))
mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32")
mel_spec = tf.convert_to_tensor(mel_spec)
speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None
_ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids)
return model.count_params()
def setup_model(num_chars, num_speakers, c, enable_tflite=False):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower())
MyModel = getattr(MyModel, c.model)
if c.model.lower() in "tacotron":
raise NotImplementedError(" [!] Tacotron model is not ready.")
# tacotron2
model = MyModel(
num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
out_channels=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
enable_tflite=enable_tflite,
)
return model

View File

@ -1,45 +0,0 @@
import datetime
import pickle
import fsspec
import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
state = {
"model": model.weights,
"optimizer": optimizer,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
"r": r,
}
state.update(kwargs)
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
layer_name = tf_var.name
try:
chkp_var_value = chkp_var_dict[layer_name]
except KeyError:
class_name = list(chkp_var_dict.keys())[0].split("/")[0]
layer_name = f"{class_name}/{layer_name}"
chkp_var_value = chkp_var_dict[layer_name]
tf.keras.backend.set_value(tf_var, chkp_var_value)
if "r" in checkpoint.keys():
model.decoder.set_r(checkpoint["r"])
return model
def load_tflite_model(tflite_path):
tflite_model = tf.lite.Interpreter(model_path=tflite_path)
tflite_model.allocate_tensors()
return tflite_model

View File

@ -1,8 +0,0 @@
import tensorflow as tf
def shape_list(x):
"""Deal with dynamic shape in tensorflow cleanly."""
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
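For a fully static tensor the helper simply returns Python ints; any dynamic dimension comes back as a scalar tensor instead. A tiny sketch reusing the module's `tf` import:

x = tf.zeros([2, 5, 80])
print(shape_list(x))  # [2, 5, 80]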

View File

@ -1,27 +0,0 @@
import fsspec
import tensorflow as tf
def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True):
"""Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is
provided, else return TFLite model."""
concrete_function = model.inference_tflite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
converter.experimental_new_converter = experimental_converter
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# save the model binary if output_path is provided
with fsspec.open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model
def load_tflite_model(tflite_path):
tflite_model = tf.lite.Interpreter(model_path=tflite_path)
tflite_model.allocate_tensors()
return tflite_model
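A hedged usage sketch: assuming `model` is a built TF Tacotron2 whose `inference_tflite` concrete function is available, conversion and loading chain together as:

convert_tacotron2_to_tflite(model, output_path="tacotron2.tflite")  # returns None, writes the binary to disk
interpreter = load_tflite_model("tacotron2.tflite")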

View File

@ -57,40 +57,65 @@ def sequence_mask(sequence_length, max_len=None):
     return mask
-def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
+def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
     """Segment each sample in a batch based on the provided segment indices
     Args:
         x (torch.tensor): Input tensor.
         segment_indices (torch.tensor): Segment indices.
         segment_size (int): Expected output segment size.
+        pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
     """
+    # pad the input tensor if it is shorter than the segment size
+    if pad_short and x.shape[-1] < segment_size:
+        x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
     segments = torch.zeros_like(x[:, :, :segment_size])
     for i in range(x.size(0)):
         index_start = segment_indices[i]
         index_end = index_start + segment_size
-        segments[i] = x[i, :, index_start:index_end]
+        x_i = x[i]
+        if pad_short and index_end > x.size(2):
+            # pad the sample if it is shorter than the segment size
+            x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
+        segments[i] = x_i[:, index_start:index_end]
     return segments
-def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
+def rand_segments(
+    x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
+):
     """Create random segments based on the input lengths.
     Args:
         x (torch.tensor): Input tensor.
         x_lengths (torch.tensor): Input lengths.
         segment_size (int): Expected output segment size.
+        let_short_samples (bool): Allow shorter samples than the segment size.
+        pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
     Shapes:
         - x: :math:`[B, C, T]`
         - x_lengths: :math:`[B]`
     """
+    _x_lenghts = x_lengths.clone()
     B, _, T = x.size()
-    if x_lengths is None:
-        x_lengths = T
-    max_idxs = x_lengths - segment_size + 1
-    assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
-    segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
+    if pad_short:
+        if T < segment_size:
+            x = torch.nn.functional.pad(x, (0, segment_size - T))
+            T = segment_size
+    if _x_lenghts is None:
+        _x_lenghts = T
+    len_diff = _x_lenghts - segment_size + 1
+    if let_short_samples:
+        _x_lenghts[len_diff < 0] = segment_size
+        len_diff = _x_lenghts - segment_size + 1
+    else:
+        assert all(
+            len_diff > 0
+        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
+    segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
     ret = segment(x, segment_indices, segment_size)
     return ret, segment_indices
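A minimal usage sketch of the updated helpers on dummy tensors (with explicit `x_lengths`, since the new code clones it before the `None` check):

import torch

x = torch.randn(2, 80, 32)            # [B, C, T]
x_lens = torch.tensor([32, 24])
segs, idxs = rand_segments(x, x_lens, segment_size=8)
print(segs.shape)                      # torch.Size([2, 80, 8])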

View File

@ -8,6 +8,8 @@ import torch
 from coqpit import Coqpit
 from torch.utils.data.sampler import WeightedRandomSampler
+from TTS.config import check_config_and_model_args
 class LanguageManager:
     """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
@ -98,6 +100,20 @@ class LanguageManager:
         """
         self._save_json(file_path, self.language_id_mapping)
+    @staticmethod
+    def init_from_config(config: Coqpit) -> "LanguageManager":
+        """Initialize the language manager from a Coqpit config.
+        Args:
+            config (Coqpit): Coqpit config.
+        """
+        language_manager = None
+        if check_config_and_model_args(config, "use_language_embedding", True):
+            if config.get("language_ids_file", None):
+                language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
+            language_manager = LanguageManager(config=config)
+        return language_manager
 def _set_file_path(path):
     """Find the language_ids.json under the given path or the above it.
@ -113,7 +129,7 @@ def _set_file_path(path):
 def get_language_weighted_sampler(items: list):
-    language_names = np.array([item[3] for item in items])
+    language_names = np.array([item["language"] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])

View File

@ -9,7 +9,7 @@ import torch
 from coqpit import Coqpit
 from torch.utils.data.sampler import WeightedRandomSampler
-from TTS.config import load_config
+from TTS.config import get_from_config_or_model_args_with_default, load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 from TTS.utils.audio import AudioProcessor
@ -118,7 +118,7 @@ class SpeakerManager:
         Returns:
             Tuple[Dict, int]: speaker IDs and number of speakers.
         """
-        speakers = sorted({item[2] for item in items})
+        speakers = sorted({item["speaker_name"] for item in items})
         speaker_ids = {name: i for i, name in enumerate(speakers)}
         num_speakers = len(speaker_ids)
         return speaker_ids, num_speakers
@ -318,6 +318,42 @@ class SpeakerManager:
         # TODO: implement speaker encoder
         raise NotImplementedError
+    @staticmethod
+    def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
+        """Initialize a speaker manager from config
+        Args:
+            config (Coqpit): Config object.
+            samples (Union[List[List], List[Dict]], optional): List of data samples to parse out the speaker names.
+                Defaults to None.
+        Returns:
+            SpeakerEncoder: Speaker encoder object.
+        """
+        speaker_manager = None
+        if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False):
+            if samples:
+                speaker_manager = SpeakerManager(data_items=samples)
+            if get_from_config_or_model_args_with_default(config, "speaker_file", None):
+                speaker_manager = SpeakerManager(
+                    speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
+                )
+            if get_from_config_or_model_args_with_default(config, "speakers_file", None):
+                speaker_manager = SpeakerManager(
+                    speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None)
+                )
+        if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False):
+            if get_from_config_or_model_args_with_default(config, "speakers_file", None):
+                speaker_manager = SpeakerManager(
+                    d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
+                )
+            if get_from_config_or_model_args_with_default(config, "d_vector_file", None):
+                speaker_manager = SpeakerManager(
+                    d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None)
+                )
+        return speaker_manager
 def _set_file_path(path):
     """Find the speakers.json under the given path or the above it.
@ -414,7 +450,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
 def get_speaker_weighted_sampler(items: list):
-    speaker_names = np.array([item[2] for item in items])
+    speaker_names = np.array([item["speaker_name"] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])

View File

@ -8,7 +8,7 @@ from torch.autograd import Variable
 def gaussian(window_size, sigma):
-    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)])
+    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
     return gauss / gauss.sum()
@ -33,8 +33,8 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
-    C1 = 0.01 ** 2
-    C2 = 0.03 ** 2
+    C1 = 0.01**2
+    C2 = 0.03**2
     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

View File

@ -1,46 +1,9 @@
import os
from typing import Dict
import numpy as np
import pkg_resources
import torch
from torch import nn
from .text import phoneme_to_sequence, text_to_sequence
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable
if "tensorflow" in installed or "tensorflow-gpu" in installed:
import tensorflow as tf
def text_to_seq(text, CONFIG, custom_symbols=None, language=None):
text_cleaner = [CONFIG.text_cleaner]
# text or phonemes to sequence vector
if CONFIG.use_phonemes:
seq = np.asarray(
phoneme_to_sequence(
text,
text_cleaner,
language if language else CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
use_espeak_phonemes=CONFIG.use_espeak_phonemes,
custom_symbols=custom_symbols,
),
dtype=np.int32,
)
else:
seq = np.asarray(
text_to_sequence(
text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
),
dtype=np.int32,
)
return seq
def numpy_to_torch(np_array, dtype, cuda=False):
    if np_array is None:
@ -51,13 +14,6 @@ def numpy_to_torch(np_array, dtype, cuda=False):
    return tensor
def numpy_to_tf(np_array, dtype):
if np_array is None:
return None
tensor = tf.convert_to_tensor(np_array, dtype=dtype)
return tensor
def compute_style_mel(style_wav, ap, cuda=False):
    style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
    if cuda:
@ -103,53 +59,6 @@ def run_model_torch(
    return outputs
def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
if CONFIG.gst and style_mel is not None:
raise NotImplementedError(" [!] GST inference not implemented for TF")
if speaker_id is not None:
raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
# TODO: handle multispeaker case
decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
return decoder_output, postnet_output, alignments, stop_tokens
def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
if CONFIG.gst and style_mel is not None:
raise NotImplementedError(" [!] GST inference not implemented for TfLite")
if speaker_id is not None:
raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
# get input and output details
input_details = model.get_input_details()
output_details = model.get_output_details()
# reshape input tensor for the new input shape
model.resize_tensor_input(input_details[0]["index"], inputs.shape)
model.allocate_tensors()
detail = input_details[0]
# input_shape = detail['shape']
model.set_tensor(detail["index"], inputs)
# run the model
model.invoke()
# collect outputs
decoder_output = model.get_tensor(output_details[0]["index"])
postnet_output = model.get_tensor(output_details[1]["index"])
# tflite model only returns feature frames
return decoder_output, postnet_output, None, None
def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
postnet_output = postnet_output[0].numpy()
decoder_output = decoder_output[0].numpy()
alignment = alignments[0].numpy()
stop_tokens = stop_tokens[0].numpy()
return postnet_output, decoder_output, alignment, stop_tokens
def parse_outputs_tflite(postnet_output, decoder_output):
postnet_output = postnet_output[0]
decoder_output = decoder_output[0]
return postnet_output, decoder_output
def trim_silence(wav, ap):
    return wav[: ap.find_endpoint(wav)]
@ -204,16 +113,12 @@ def synthesis(
     text,
     CONFIG,
     use_cuda,
-    ap,
     speaker_id=None,
     style_wav=None,
-    enable_eos_bos_chars=False,  # pylint: disable=unused-argument
     use_griffin_lim=False,
     do_trim_silence=False,
     d_vector=None,
     language_id=None,
-    language_name=None,
-    backend="torch",
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
     the vocoder model.
@ -231,9 +136,6 @@ def synthesis(
         use_cuda (bool):
             Enable/disable CUDA.
-        ap (TTS.tts.utils.audio.AudioProcessor):
-            The audio processor for extracting features and pre/post-processing audio.
         speaker_id (int):
             Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None.
@ -251,74 +153,51 @@ def synthesis(
         language_id (int):
             Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
-        language_name (str):
-            Language name corresponding to the language code used by the phonemizer. Defaults to None.
-        backend (str):
-            tf or torch. Defaults to "torch".
     """
     # GST processing
     style_mel = None
-    custom_symbols = None
-    if style_wav:
-        style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
-    elif CONFIG.has("gst") and CONFIG.gst and not style_wav:
-        if CONFIG.gst.gst_style_input_weights:
-            style_mel = CONFIG.gst.gst_style_input_weights
-    if hasattr(model, "make_symbols"):
-        custom_symbols = model.make_symbols(CONFIG)
-    # preprocess the given text
-    text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name)
+    if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
+        if isinstance(style_wav, dict):
+            style_mel = style_wav
+        else:
+            style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+    # convert text to sequence of token IDs
+    text_inputs = np.asarray(
+        model.tokenizer.text_to_ids(text, language=language_id),
+        dtype=np.int32,
+    )
     # pass tensors to backend
-    if backend == "torch":
-        if speaker_id is not None:
-            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
-        if d_vector is not None:
-            d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
-        if language_id is not None:
-            language_id = id_to_torch(language_id, cuda=use_cuda)
-        if not isinstance(style_mel, dict):
-            style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
-        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
-        text_inputs = text_inputs.unsqueeze(0)
-    elif backend in ["tf", "tflite"]:
-        # TODO: handle speaker id for tf model
-        style_mel = numpy_to_tf(style_mel, tf.float32)
-        text_inputs = numpy_to_tf(text_inputs, tf.int32)
-        text_inputs = tf.expand_dims(text_inputs, 0)
+    if speaker_id is not None:
+        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+    if d_vector is not None:
+        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+    if language_id is not None:
+        language_id = id_to_torch(language_id, cuda=use_cuda)
+    if not isinstance(style_mel, dict):
+        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+    text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+    text_inputs = text_inputs.unsqueeze(0)
     # synthesize voice
-    if backend == "torch":
-        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
-        model_outputs = outputs["model_outputs"]
-        model_outputs = model_outputs[0].data.cpu().numpy()
-        alignments = outputs["alignments"]
-    elif backend == "tf":
-        decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
-            model, text_inputs, CONFIG, speaker_id, style_mel
-        )
-        model_outputs, decoder_output, alignments, stop_tokens = parse_outputs_tf(
-            postnet_output, decoder_output, alignments, stop_tokens
-        )
-    elif backend == "tflite":
-        decoder_output, postnet_output, alignments, stop_tokens = run_model_tflite(
-            model, text_inputs, CONFIG, speaker_id, style_mel
-        )
-        model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
+    outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
+    model_outputs = outputs["model_outputs"]
+    model_outputs = model_outputs[0].data.cpu().numpy()
+    alignments = outputs["alignments"]
     # convert outputs to numpy
     # plot results
     wav = None
-    if hasattr(model, "END2END") and model.END2END:
-        wav = model_outputs.squeeze(0)
-    else:
+    model_outputs = model_outputs.squeeze()
+    if model_outputs.ndim == 2:  # [T, C_spec]
         if use_griffin_lim:
-            wav = inv_spectrogram(model_outputs, ap, CONFIG)
+            wav = inv_spectrogram(model_outputs, model.ap, CONFIG)
             # trim silence
             if do_trim_silence:
-                wav = trim_silence(wav, ap)
+                wav = trim_silence(wav, model.ap)
+    else:  # [T,]
+        wav = model_outputs
     return_dict = {
         "wav": wav,
         "alignments": alignments,

View File

@ -1,276 +1 @@
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
-# -*- coding: utf-8 -*-
# adapted from https://github.com/keithito/tacotron
import re
from typing import Dict, List
import gruut
from gruut_ipa import IPA
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols
# pylint: disable=unnecessary-comprehension
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_symbols = symbols
_phonemes = phonemes
# Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
# Regular expression matching punctuations, ignoring empty space
PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+"
# Table for str.translate to fix gruut/TTS phoneme mismatch
GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
def text2phone(text, language, use_espeak_phonemes=False, keep_stress=False):
"""Convert graphemes to phonemes.
Parameters:
text (str): text to phonemize
language (str): language of the text
Returns:
ph (str): phonemes as a string separated by "|"
ph = "ɪ|g|ˈ|z|æ|m|p|ə|l"
"""
# TO REVIEW : How to have a good implementation for this?
if language == "zh-CN":
ph = chinese_text_to_phonemes(text)
return ph
if language == "ja-jp":
ph = japanese_text_to_phonemes(text)
return ph
if not gruut.is_language_supported(language):
raise ValueError(f" [!] Language {language} is not supported for phonemization.")
# Use gruut for phonemization
ph_list = []
for sentence in gruut.sentences(text, lang=language, espeak=use_espeak_phonemes):
for word in sentence:
if word.is_break:
# Use actual character for break phoneme (e.g., comma)
if ph_list:
# Join with previous word
ph_list[-1].append(word.text)
else:
# First word is punctuation
ph_list.append([word.text])
elif word.phonemes:
# Add phonemes for word
word_phonemes = []
for word_phoneme in word.phonemes:
if not keep_stress:
# Remove primary/secondary stress
word_phoneme = IPA.without_stress(word_phoneme)
word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
if word_phoneme:
# Flatten phonemes
word_phonemes.extend(word_phoneme)
if word_phonemes:
ph_list.append(word_phonemes)
# Join and re-split to break apart diphthongs, suprasegmentals, etc.
ph_words = ["|".join(word_phonemes) for word_phonemes in ph_list]
ph = "| ".join(ph_words)
return ph
def intersperse(sequence, token):
result = [token] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
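For example, interspersing a blank token id of 0:

print(intersperse([7, 8, 9], 0))  # [0, 7, 0, 8, 0, 9, 0]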
def pad_with_eos_bos(phoneme_sequence, tp=None):
# pylint: disable=global-statement
global _phonemes_to_id, _bos, _eos
if tp:
_bos = tp["bos"]
_eos = tp["eos"]
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]
def phoneme_to_sequence(
text: str,
cleaner_names: List[str],
language: str,
enable_eos_bos: bool = False,
custom_symbols: List[str] = None,
tp: Dict = None,
add_blank: bool = False,
use_espeak_phonemes: bool = False,
) -> List[int]:
"""Converts a string of phonemes to a sequence of IDs.
If `custom_symbols` is provided, it will override the default symbols.
Args:
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
language (str): text language key for phonemization.
enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequences
Returns:
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _phonemes_to_id, _phonemes
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
sequence = []
clean_text = _clean_text(text, cleaner_names)
to_phonemes = text2phone(clean_text, language, use_espeak_phonemes=use_espeak_phonemes)
if to_phonemes is None:
print("!! After phoneme conversion the result is None. -- {} ".format(clean_text))
# iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation.
for phoneme in filter(None, to_phonemes.split("|")):
sequence += _phoneme_to_sequence(phoneme)
# Append EOS char
if enable_eos_bos:
sequence = pad_with_eos_bos(sequence, tp=tp)
if add_blank:
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
return sequence
def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None):
# pylint: disable=global-statement
"""Converts a sequence of IDs back to a string"""
global _id_to_phonemes, _phonemes
if add_blank:
sequence = list(filter(lambda x: x != len(_phonemes), sequence))
result = ""
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp)
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
for symbol_id in sequence:
if symbol_id in _id_to_phonemes:
s = _id_to_phonemes[symbol_id]
result += s
return result.replace("}{", " ")
def text_to_sequence(
text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
) -> List[int]:
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
If `custom_symbols` is provided, it will override the default symbols.
Args:
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
Returns:
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _symbol_to_id, _symbols
if custom_symbols is not None:
_symbols = custom_symbols
elif tp:
_symbols, _ = make_symbols(**tp)
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while text:
m = _CURLY_RE.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
if add_blank:
sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols)
return sequence
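A minimal sketch of how this (removed by this commit) API was used, with the default grapheme set and the English cleaner:

seq = text_to_sequence("Hello, world!", ["english_cleaners"])
print(len(seq), seq[:5])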
def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
"""Converts a sequence of IDs back to a string"""
# pylint: disable=global-statement
global _id_to_symbol, _symbols
if add_blank:
sequence = list(filter(lambda x: x != len(_symbols), sequence))
if custom_symbols is not None:
_symbols = custom_symbols
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
elif tp:
_symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
result = ""
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == "@":
s = "{%s}" % s[1:]
result += s
return result.replace("}{", " ")
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception("Unknown cleaner: %s" % name)
text = cleaner(text)
return text
def _symbols_to_sequence(syms):
return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)]
def _phoneme_to_sequence(phons):
return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)]
def _arpabet_to_sequence(text):
return _symbols_to_sequence(["@" + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s not in ["~", "^", "_"]
def _should_keep_phoneme(p):
return p in _phonemes_to_id and p not in ["~", "^", "_"]

View File

@ -0,0 +1,468 @@
from dataclasses import replace
from typing import Dict
from TTS.tts.configs.shared_configs import CharactersConfig
def parse_symbols():
return {
"pad": _pad,
"eos": _eos,
"bos": _bos,
"characters": _characters,
"punctuations": _punctuations,
"phonemes": _phonemes,
}
# DEFAULT SET OF GRAPHEMES
_pad = "<PAD>"
_eos = "<EOS>"
_bos = "<BOS>"
_blank = "<BLNK>"  # TODO: check if we need this alongside PAD
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_punctuations = "!'(),-.:;? "
# DEFAULT SET OF IPA PHONEMES
# Phonemes definition (All IPA characters)
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "ˈˌːˑ"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
class BaseVocabulary:
"""Base Vocabulary class.
This class only needs a vocabulary dictionary without specifying the characters.
Args:
vocab (Dict): A dictionary of characters and their corresponding indices.
"""
def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
self.vocab = vocab
self.pad = pad
self.blank = blank
self.bos = bos
self.eos = eos
@property
def pad_id(self) -> int:
"""Return the index of the padding character. If the padding character is not specified, return the length
of the vocabulary."""
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
@property
def blank_id(self) -> int:
"""Return the index of the blank character. If the blank character is not specified, return the length of
the vocabulary."""
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
@property
def vocab(self):
"""Return the vocabulary dictionary."""
return self._vocab
@vocab.setter
def vocab(self, vocab):
"""Set the vocabulary dictionary and character mapping dictionaries."""
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
}
@staticmethod
def init_from_config(config, **kwargs):
"""Initialize from the given config."""
if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
return (
BaseVocabulary(
config.characters.vocab_dict,
config.characters.pad,
config.characters.blank,
config.characters.bos,
config.characters.eos,
),
config,
)
return BaseVocabulary(**kwargs), config
@property
def num_chars(self):
"""Return number of tokens in the vocabulary."""
return len(self._vocab)
def char_to_id(self, char: str) -> int:
"""Map a character to a token ID."""
try:
return self._char_to_id[char]
except KeyError as e:
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
def id_to_char(self, idx: int) -> str:
"""Map a token ID to a character."""
return self._id_to_char[idx]
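A small sketch; the docstring above says `vocab` is a Dict, but since the setter only enumerates it, a plain list of tokens behaves the same way:

vocab = BaseVocabulary(vocab=["<PAD>", "a", "b", "c"], pad="<PAD>")
print(vocab.num_chars)         # 4
print(vocab.char_to_id("b"))   # 2
print(vocab.pad_id)            # 0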
class BaseCharacters:
"""🐸BaseCharacters class
Every new character class should inherit from this.
Characters are ordered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```.
If you need a custom order, you need to inherit from this class and override the ```_create_vocab``` method.
Args:
characters (str):
Main set of characters to be used in the vocabulary.
punctuations (str):
Characters to be treated as punctuation.
pad (str):
Special padding character that would be ignored by the model.
eos (str):
End of the sentence character.
bos (str):
Beginning of the sentence character.
blank (str):
Optional character used between characters by some models for better prosody.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True.
"""
def __init__(
self,
characters: str = None,
punctuations: str = None,
pad: str = None,
eos: str = None,
bos: str = None,
blank: str = None,
is_unique: bool = False,
is_sorted: bool = True,
) -> None:
self._characters = characters
self._punctuations = punctuations
self._pad = pad
self._eos = eos
self._bos = bos
self._blank = blank
self.is_unique = is_unique
self.is_sorted = is_sorted
self._create_vocab()
@property
def pad_id(self) -> int:
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
@property
def blank_id(self) -> int:
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
@property
def characters(self):
return self._characters
@characters.setter
def characters(self, characters):
self._characters = characters
self._create_vocab()
@property
def punctuations(self):
return self._punctuations
@punctuations.setter
def punctuations(self, punctuations):
self._punctuations = punctuations
self._create_vocab()
@property
def pad(self):
return self._pad
@pad.setter
def pad(self, pad):
self._pad = pad
self._create_vocab()
@property
def eos(self):
return self._eos
@eos.setter
def eos(self, eos):
self._eos = eos
self._create_vocab()
@property
def bos(self):
return self._bos
@bos.setter
def bos(self, bos):
self._bos = bos
self._create_vocab()
@property
def blank(self):
return self._blank
@blank.setter
def blank(self, blank):
self._blank = blank
self._create_vocab()
@property
def vocab(self):
return self._vocab
@vocab.setter
def vocab(self, vocab):
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
}
@property
def num_chars(self):
return len(self._vocab)
def _create_vocab(self):
_vocab = self._characters
if self.is_unique:
_vocab = list(set(_vocab))
if self.is_sorted:
_vocab = sorted(_vocab)
_vocab = list(_vocab)
_vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
self.vocab = _vocab + list(self._punctuations)
if self.is_unique:
duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
assert (
len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
), f" [!] There are duplicate characters in the character set. {duplicates}"
def char_to_id(self, char: str) -> int:
try:
return self._char_to_id[char]
except KeyError as e:
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
def id_to_char(self, idx: int) -> str:
return self._id_to_char[idx]
def print_log(self, level: int = 0):
"""
Prints the vocabulary in a nice format.
"""
indent = "\t" * level
print(f"{indent}| > Characters: {self._characters}")
print(f"{indent}| > Punctuations: {self._punctuations}")
print(f"{indent}| > Pad: {self._pad}")
print(f"{indent}| > EOS: {self._eos}")
print(f"{indent}| > BOS: {self._bos}")
print(f"{indent}| > Blank: {self._blank}")
print(f"{indent}| > Vocab: {self.vocab}")
print(f"{indent}| > Num chars: {self.num_chars}")
@staticmethod
def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument
"""Init your character class from a config.
Implement this method for your subclass.
"""
# use character set from config
if config.characters is not None:
return BaseCharacters(**config.characters), config
# return default character set
characters = BaseCharacters()
new_config = replace(config, characters=characters.to_config())
return characters, new_config
def to_config(self) -> "CharactersConfig":
return CharactersConfig(
characters=self._characters,
punctuations=self._punctuations,
pad=self._pad,
eos=self._eos,
bos=self._bos,
blank=self._blank,
is_unique=self.is_unique,
is_sorted=self.is_sorted,
)
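A short sketch of the resulting ordering, using toy character sets:

chars = BaseCharacters(characters="abc", punctuations="!? ", pad="<PAD>", eos="<EOS>", bos="<BOS>", blank="<BLNK>")
print(chars.vocab)            # ['<PAD>', '<EOS>', '<BOS>', '<BLNK>', 'a', 'b', 'c', '!', '?', ' ']
print(chars.char_to_id("a"))  # 4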
class IPAPhonemes(BaseCharacters):
"""🐸IPAPhonemes class to manage `TTS.tts` model vocabulary
Intended to be used with models using IPAPhonemes as input.
It uses system defaults for the undefined class arguments.
Args:
characters (str):
Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`.
punctuations (str):
Characters to be treated as punctuation. Defaults to `_punctuations`.
pad (str):
Special padding character that would be ignored by the model. Defaults to `_pad`.
eos (str):
End of the sentence character. Defaults to `_eos`.
bos (str):
Beginning of the sentence character. Defaults to `_bos`.
blank (str):
Optional character used between characters by some models for better prosody. Defaults to `_blank`.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
"""
def __init__(
self,
characters: str = _phonemes,
punctuations: str = _punctuations,
pad: str = _pad,
eos: str = _eos,
bos: str = _bos,
blank: str = _blank,
is_unique: bool = False,
is_sorted: bool = True,
) -> None:
super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
@staticmethod
def init_from_config(config: "Coqpit"):
"""Init a IPAPhonemes object from a model config
If characters are not defined in the config, it will be set to the default characters and the config
will be updated.
"""
# band-aid for compatibility with old models
if "characters" in config and config.characters is not None:
if "phonemes" in config.characters and config.characters.phonemes is not None:
config.characters["characters"] = config.characters["phonemes"]
return (
IPAPhonemes(
characters=config.characters["characters"],
punctuations=config.characters["punctuations"],
pad=config.characters["pad"],
eos=config.characters["eos"],
bos=config.characters["bos"],
blank=config.characters["blank"],
is_unique=config.characters["is_unique"],
is_sorted=config.characters["is_sorted"],
),
config,
)
# use character set from config
if config.characters is not None:
return IPAPhonemes(**config.characters), config
# return default character set
characters = IPAPhonemes()
new_config = replace(config, characters=characters.to_config())
return characters, new_config
class Graphemes(BaseCharacters):
"""🐸Graphemes class to manage `TTS.tts` model vocabulary
Intended to be used with models using graphemes as input.
It uses system defaults for the undefined class arguments.
Args:
characters (str):
Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`.
punctuations (str):
Characters to be treated as punctuation. Defaults to `_punctuations`.
pad (str):
Special padding character that would be ignored by the model. Defaults to `_pad`.
eos (str):
End of the sentence character. Defaults to `_eos`.
bos (str):
Beginning of the sentence character. Defaults to `_bos`.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
"""
def __init__(
self,
characters: str = _characters,
punctuations: str = _punctuations,
pad: str = _pad,
eos: str = _eos,
bos: str = _bos,
blank: str = _blank,
is_unique: bool = False,
is_sorted: bool = True,
) -> None:
super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
@staticmethod
def init_from_config(config: "Coqpit"):
"""Init a Graphemes object from a model config
If characters are not defined in the config, it will be set to the default characters and the config
will be updated.
"""
if config.characters is not None:
# band-aid for compatibility with old models
if "phonemes" in config.characters:
return (
Graphemes(
characters=config.characters["characters"],
punctuations=config.characters["punctuations"],
pad=config.characters["pad"],
eos=config.characters["eos"],
bos=config.characters["bos"],
blank=config.characters["blank"],
is_unique=config.characters["is_unique"],
is_sorted=config.characters["is_sorted"],
),
config,
)
return Graphemes(**config.characters), config
characters = Graphemes()
new_config = replace(config, characters=characters.to_config())
return characters, new_config
if __name__ == "__main__":
gr = Graphemes()
ph = IPAPhonemes()
gr.print_log()
ph.print_log()

View File

@ -19,7 +19,7 @@ def _chinese_pinyin_to_phoneme(pinyin: str) -> str:
     return phoneme + tone
-def chinese_text_to_phonemes(text: str) -> str:
+def chinese_text_to_phonemes(text: str, seperator: str = "|") -> str:
     tokenized_text = jieba.cut(text, HMM=False)
     tokenized_text = " ".join(tokenized_text)
     pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text)
@ -34,4 +34,4 @@ def chinese_text_to_phonemes(text: str) -> str:
         else:  # is ponctuation or other
             results += list(token)
-    return "|".join(results)
+    return seperator.join(results)

View File

@@ -1,12 +1,16 @@
+"""Set of default text cleaners"""
+# TODO: pick the cleaner for languages dynamically
 import re
 from anyascii import anyascii
 from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text
-from .abbreviations import abbreviations_en, abbreviations_fr
-from .number_norm import normalize_numbers
-from .time import expand_time_english
+from .english.abbreviations import abbreviations_en
+from .english.number_norm import normalize_numbers as en_normalize_numbers
+from .english.time_norm import expand_time_english
+from .french.abbreviations import abbreviations_fr
 # Regular expression matching whitespace:
 _whitespace_re = re.compile(r"\s+")
@@ -22,10 +26,6 @@ def expand_abbreviations(text, lang="en"):
     return text
-def expand_numbers(text):
-    return normalize_numbers(text)
 def lowercase(text):
     return text.lower()
@@ -92,7 +92,17 @@ def english_cleaners(text):
     # text = convert_to_ascii(text)
     text = lowercase(text)
     text = expand_time_english(text)
-    text = expand_numbers(text)
+    text = en_normalize_numbers(text)
+    text = expand_abbreviations(text)
+    text = replace_symbols(text)
+    text = remove_aux_symbols(text)
+    text = collapse_whitespace(text)
+    return text
+def phoneme_cleaners(text):
+    """Pipeline for phonemes mode, including number and abbreviation expansion."""
+    text = en_normalize_numbers(text)
     text = expand_abbreviations(text)
     text = replace_symbols(text)
     text = remove_aux_symbols(text)
@@ -126,17 +136,6 @@ def chinese_mandarin_cleaners(text: str) -> str:
     return text
-def phoneme_cleaners(text):
-    """Pipeline for phonemes mode, including number and abbreviation expansion."""
-    text = expand_numbers(text)
-    # text = convert_to_ascii(text)
-    text = expand_abbreviations(text)
-    text = replace_symbols(text)
-    text = remove_aux_symbols(text)
-    text = collapse_whitespace(text)
-    return text
 def multilingual_cleaners(text):
     """Pipeline for multilingual text"""
     text = lowercase(text)
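A sketch of the reorganized cleaner pipelines; the sample sentence is arbitrary and only functions visible in this file are called:

from TTS.tts.utils.text.cleaners import english_cleaners, phoneme_cleaners

sample = "Dr. Jones arrives at 10:30am with 3 colleagues."
print(english_cleaners(sample))   # lowercase + time, number and abbreviation expansion
print(phoneme_cleaners(sample))   # number and abbreviation expansion, no lowercasing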

View File

@@ -0,0 +1,26 @@
import re
# List of (regular expression, replacement) pairs for abbreviations in english:
abbreviations_en = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
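The pairs above are meant to be applied one regex at a time; a small sketch mirroring how `expand_abbreviations` in the cleaners module uses them:

import re

from TTS.tts.utils.text.english.abbreviations import abbreviations_en

def expand_en(text: str) -> str:
    # apply every (pattern, replacement) pair in order
    for regex, replacement in abbreviations_en:
        text = re.sub(regex, replacement, text)
    return text

print(expand_en("mr. smith visited dr. jones."))  # -> "mister smith visited doctor jones."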

View File

@@ -1,30 +1,5 @@
 import re
-# List of (regular expression, replacement) pairs for abbreviations in english:
-abbreviations_en = [
-    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
-    for x in [
-        ("mrs", "misess"),
-        ("mr", "mister"),
-        ("dr", "doctor"),
-        ("st", "saint"),
-        ("co", "company"),
-        ("jr", "junior"),
-        ("maj", "major"),
-        ("gen", "general"),
-        ("drs", "doctors"),
-        ("rev", "reverend"),
-        ("lt", "lieutenant"),
-        ("hon", "honorable"),
-        ("sgt", "sergeant"),
-        ("capt", "captain"),
-        ("esq", "esquire"),
-        ("ltd", "limited"),
-        ("col", "colonel"),
-        ("ft", "fort"),
-    ]
-]
 # List of (regular expression, replacement) pairs for abbreviations in french:
 abbreviations_fr = [
     (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])

View File

@@ -0,0 +1,57 @@
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer
PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, JA_JP_Phonemizer)}
ESPEAK_LANGS = list(ESpeak.supported_languages().keys())
GRUUT_LANGS = list(Gruut.supported_languages())
# Dict setting default phonemizers for each language
DEF_LANG_TO_PHONEMIZER = {
"ja-jp": JA_JP_Phonemizer.name(),
"zh-cn": ZH_CN_Phonemizer.name(),
}
# Add Gruut languages
_ = [Gruut.name()] * len(GRUUT_LANGS)
_new_dict = dict(list(zip(GRUUT_LANGS, _)))
DEF_LANG_TO_PHONEMIZER.update(_new_dict)
# Add ESpeak languages and override any existing ones
_ = [ESpeak.name()] * len(ESPEAK_LANGS)
_new_dict = dict(list(zip(list(ESPEAK_LANGS), _)))
DEF_LANG_TO_PHONEMIZER.update(_new_dict)
DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"]
def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer:
"""Initiate a phonemizer by name
Args:
name (str):
Name of the phonemizer that should match `phonemizer.name()`.
kwargs (dict):
Extra keyword arguments that should be passed to the phonemizer.
"""
if name == "espeak":
return ESpeak(**kwargs)
if name == "gruut":
return Gruut(**kwargs)
if name == "zh_cn_phonemizer":
return ZH_CN_Phonemizer(**kwargs)
if name == "ja_jp_phonemizer":
return JA_JP_Phonemizer(**kwargs)
raise ValueError(f"Phonemizer {name} not found")
if __name__ == "__main__":
print(DEF_LANG_TO_PHONEMIZER)
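A sketch of resolving a phonemizer through the new registry, assuming `espeak-ng` (or the relevant Gruut language data) is installed on the system:

from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name

lang = "en-us"
phonemizer = get_phonemizer_by_name(DEF_LANG_TO_PHONEMIZER[lang], language=lang)
print(phonemizer.name())
print(phonemizer.phonemize("Be a voice, not an echo.", separator="|"))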

View File

@@ -0,0 +1,141 @@
import abc
from typing import List, Tuple
from TTS.tts.utils.text.punctuation import Punctuation
class BasePhonemizer(abc.ABC):
"""Base phonemizer class
Phonemization follows the following steps:
1. Preprocessing:
- remove empty lines
- remove punctuation
- keep track of punctuation marks
2. Phonemization:
- convert text to phonemes
3. Postprocessing:
- join phonemes
- restore punctuation marks
Args:
language (str):
Language used by the phonemizer.
punctuations (List[str]):
List of punctuation marks to be preserved.
keep_puncs (bool):
Whether to preserve punctuation marks or not.
"""
def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
# ensure the backend is installed on the system
if not self.is_available():
raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
# ensure the backend support the requested language
self._language = self._init_language(language)
# setup punctuation processing
self._keep_puncs = keep_puncs
self._punctuator = Punctuation(punctuations)
def _init_language(self, language):
"""Language initialization
This method may be overloaded in child classes (see Segments backend)
"""
if not self.is_supported_language(language):
raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
return language
@property
def language(self):
"""The language code configured to be used for phonemization"""
return self._language
@staticmethod
@abc.abstractmethod
def name():
"""The name of the backend"""
...
@classmethod
@abc.abstractmethod
def is_available(cls):
"""Returns True if the backend is installed, False otherwise"""
...
@classmethod
@abc.abstractmethod
def version(cls):
"""Return the backend version as a tuple (major, minor, patch)"""
...
@staticmethod
@abc.abstractmethod
def supported_languages():
"""Return a dict of language codes -> name supported by the backend"""
...
def is_supported_language(self, language):
"""Returns True if `language` is supported by the backend"""
return language in self.supported_languages()
@abc.abstractmethod
def _phonemize(self, text, separator):
"""The main phonemization method"""
def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
"""Preprocess the text before phonemization
1. remove spaces
2. remove punctuation
Override this if you need a different behaviour
"""
text = text.strip()
if self._keep_puncs:
# a tuple (text, punctuation marks)
return self._punctuator.strip_to_restore(text)
return [self._punctuator.strip(text)], []
def _phonemize_postprocess(self, phonemized, punctuations) -> str:
"""Postprocess the raw phonemized output
Override this if you need a different behaviour
"""
if self._keep_puncs:
return self._punctuator.restore(phonemized, punctuations)[0]
return phonemized[0]
def phonemize(self, text: str, separator="|") -> str:
"""Returns the `text` phonemized for the given language
Args:
text (str):
Text to be phonemized.
separator (str):
string separator used between phonemes. Defaults to '|'.
Returns:
(str): Phonemized text
"""
text, punctuations = self._phonemize_preprocess(text)
phonemized = []
for t in text:
p = self._phonemize(t, separator)
phonemized.append(p)
phonemized = self._phonemize_postprocess(phonemized, punctuations)
return phonemized
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.language}")
print(f"{indent}| > phoneme backend: {self.name()}")

View File

@@ -0,0 +1,225 @@
import logging
import subprocess
from typing import Dict, List
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation
def is_tool(name):
from shutil import which
return which(name) is not None
# priority: espeakng > espeak
if is_tool("espeak-ng"):
_DEF_ESPEAK_LIB = "espeak-ng"
elif is_tool("espeak"):
_DEF_ESPEAK_LIB = "espeak"
else:
_DEF_ESPEAK_LIB = None
def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]:
"""Run espeak with the given arguments."""
cmd = [
espeak_lib,
"-q",
"-b",
"1", # UTF8 text encoding
]
cmd.extend(args)
logging.debug("espeakng: executing %s", repr(cmd))
with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as p:
res = iter(p.stdout.readline, b"")
if not sync:
p.stdout.close()
if p.stderr:
p.stderr.close()
if p.stdin:
p.stdin.close()
return res
res2 = []
for line in res:
res2.append(line)
p.stdout.close()
if p.stderr:
p.stderr.close()
if p.stdin:
p.stdin.close()
p.wait()
return res2
class ESpeak(BasePhonemizer):
"""ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line the perform G2P
Args:
language (str):
Valid language code for the used backend.
backend (str):
Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically
preferring `espeak-ng` over `espeak`. Defaults to None.
punctuations (str):
Characters to be treated as punctuation. Defaults to Punctuation.default_puncs().
keep_puncs (bool):
If True, keep the punctuations after phonemization. Defaults to True.
Example:
>>> from TTS.tts.utils.text.phonemizers import ESpeak
>>> phonemizer = ESpeak("tr")
>>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|")
'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.'
"""
_ESPEAK_LIB = _DEF_ESPEAK_LIB
def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True):
if self._ESPEAK_LIB is None:
raise Exception(" [!] No espeak backend found. Install espeak-ng or espeak to your system.")
self.backend = self._ESPEAK_LIB
# band-aid for backwards compatibility
if language == "en":
language = "en-us"
super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
if backend is not None:
self.backend = backend
@property
def backend(self):
return self._ESPEAK_LIB
@backend.setter
def backend(self, backend):
if backend not in ["espeak", "espeak-ng"]:
raise Exception("Unknown backend: %s" % backend)
self._ESPEAK_LIB = backend
# skip first two characters of the returned text
# "_ p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
# ^^
self.num_skip_chars = 2
if backend == "espeak-ng":
# skip the first character of the returned text
# "_p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
# ^
self.num_skip_chars = 1
def auto_set_espeak_lib(self) -> None:
if is_tool("espeak-ng"):
self._ESPEAK_LIB = "espeak-ng"
elif is_tool("espeak"):
self._ESPEAK_LIB = "espeak"
else:
raise Exception("Cannot set backend automatically. espeak-ng or espeak not found")
@staticmethod
def name():
return "espeak"
def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str:
"""Convert input text to phonemes.
Args:
text (str):
Text to be converted to phonemes.
tie (bool, optional) : When True use a '͡' character between
consecutive characters of a single phoneme. Else separate phoneme
with '_'. This option requires espeak>=1.49. Defaults to False.
"""
# set arguments
args = ["-v", f"{self._language}"]
# espeak and espeak-ng parse `ipa` differently
if tie:
# use '͡' between phonemes
if self.backend == "espeak":
args.append("--ipa=1")
else:
args.append("--ipa=3")
else:
# split with '_'
if self.backend == "espeak":
args.append("--ipa=3")
else:
args.append("--ipa=1")
if tie:
args.append("--tie=%s" % tie)
args.append('"' + text + '"')
# compute phonemes
phonemes = ""
for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
logging.debug("line: %s", repr(line))
phonemes += line.decode("utf8").strip()[self.num_skip_chars :] # skip initial redundant characters
return phonemes.replace("_", separator)
def _phonemize(self, text, separator=None):
return self.phonemize_espeak(text, separator, tie=False)
@staticmethod
def supported_languages() -> Dict:
"""Get a dictionary of supported languages.
Returns:
Dict: Dictionary of language codes.
"""
if _DEF_ESPEAK_LIB is None:
return {}
args = ["--voices"]
langs = {}
count = 0
for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True):
line = line.decode("utf8").strip()
if count > 0:
cols = line.split()
lang_code = cols[1]
lang_name = cols[3]
langs[lang_code] = lang_name
logging.debug("line: %s", repr(line))
count += 1
return langs
def version(self) -> str:
"""Get the version of the used backend.
Returns:
str: Version of the used backend.
"""
args = ["--version"]
for line in _espeak_exe(self.backend, args, sync=True):
version = line.decode("utf8").strip().split()[2]
logging.debug("line: %s", repr(line))
return version
@classmethod
def is_available(cls):
"""Return true if ESpeak is available else false"""
return is_tool("espeak") or is_tool("espeak-ng")
if __name__ == "__main__":
e = ESpeak(language="en-us")
print(e.supported_languages())
print(e.version())
print(e.language)
print(e.name())
print(e.is_available())
e = ESpeak(language="en-us", keep_puncs=False)
print("`" + e.phonemize("hello how are you today?") + "`")
e = ESpeak(language="en-us", keep_puncs=True)
print("`" + e.phonemize("hello how are you today?") + "`")

View File

@@ -0,0 +1,151 @@
import importlib
from typing import List
import gruut
from gruut_ipa import IPA
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation
# Table for str.translate to fix gruut/TTS phoneme mismatch
GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
class Gruut(BasePhonemizer):
"""Gruut wrapper for G2P
Args:
language (str):
Valid language code for the used backend.
punctuations (str):
Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.
keep_puncs (bool):
If true, keep the punctuations after phonemization. Defaults to True.
use_espeak_phonemes (bool):
If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False.
keep_stress (bool):
If true, keep the stress characters after phonemization. Defaults to False.
Example:
>>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
>>> phonemizer = Gruut('en-us')
>>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
"""
def __init__(
self,
language: str,
punctuations=Punctuation.default_puncs(),
keep_puncs=True,
use_espeak_phonemes=False,
keep_stress=False,
):
super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
self.use_espeak_phonemes = use_espeak_phonemes
self.keep_stress = keep_stress
@staticmethod
def name():
return "gruut"
def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument
"""Convert input text to phonemes.
Gruut phonemizes the given `str` by separating each phoneme character with `separator`, even for characters
that constitute a single sound.
It doesn't affect 🐸TTS since it individually converts each character to token IDs.
Examples::
"hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`
Args:
text (str):
Text to be converted to phonemes.
tie (bool, optional) : When True use a '͡' character between
consecutive characters of a single phoneme. Else separate phoneme
with '_'. This option requires espeak>=1.49. Defaults to False.
"""
ph_list = []
for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
for word in sentence:
if word.is_break:
# Use actual character for break phoneme (e.g., comma)
if ph_list:
# Join with previous word
ph_list[-1].append(word.text)
else:
# First word is punctuation
ph_list.append([word.text])
elif word.phonemes:
# Add phonemes for word
word_phonemes = []
for word_phoneme in word.phonemes:
if not self.keep_stress:
# Remove primary/secondary stress
word_phoneme = IPA.without_stress(word_phoneme)
word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
if word_phoneme:
# Flatten phonemes
word_phonemes.extend(word_phoneme)
if word_phonemes:
ph_list.append(word_phonemes)
ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
ph = f"{separator} ".join(ph_words)
return ph
def _phonemize(self, text, separator):
return self.phonemize_gruut(text, separator, tie=False)
def is_supported_language(self, language):
"""Returns True if `language` is supported by the backend"""
return gruut.is_language_supported(language)
@staticmethod
def supported_languages() -> List:
"""Get a dictionary of supported languages.
Returns:
List: List of language codes.
"""
return list(gruut.get_supported_languages())
def version(self):
"""Get the version of the used backend.
Returns:
str: Version of the used backend.
"""
return gruut.__version__
@classmethod
def is_available(cls):
"""Return true if ESpeak is available else false"""
return importlib.util.find_spec("gruut") is not None
if __name__ == "__main__":
e = Gruut(language="en-us")
print(e.supported_languages())
print(e.version())
print(e.language)
print(e.name())
print(e.is_available())
e = Gruut(language="en-us", keep_puncs=False)
print("`" + e.phonemize("hello how are you today?") + "`")
e = Gruut(language="en-us", keep_puncs=True)
print("`" + e.phonemize("hello how, are you today?") + "`")

View File

@@ -0,0 +1,72 @@
from typing import Dict
from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
_DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】"
_TRANS_TABLE = {"、": ","}  # map the Japanese ideographic comma to a plain comma
def trans(text):
for i, j in _TRANS_TABLE.items():
text = text.replace(i, j)
return text
class JA_JP_Phonemizer(BasePhonemizer):
"""🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer`
TODO: someone with JA knowledge should check this implementation
Example:
>>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer
>>> phonemizer = JA_JP_Phonemizer()
>>> phonemizer.phonemize("どちらに行きますか?", separator="|")
'd|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?'
"""
language = "ja-jp"
def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument
super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)
@staticmethod
def name():
return "ja_jp_phonemizer"
def _phonemize(self, text: str, separator: str = "|") -> str:
ph = japanese_text_to_phonemes(text)
if separator is not None and separator != "":
return separator.join(ph)
return ph
def phonemize(self, text: str, separator="|") -> str:
"""Custom phonemize for JP_JA
Skip pre-post processing steps used by the other phonemizers.
"""
return self._phonemize(text, separator)
@staticmethod
def supported_languages() -> Dict:
return {"ja-jp": "Japanese (Japan)"}
def version(self) -> str:
return "0.0.1"
def is_available(self) -> bool:
return True
# if __name__ == "__main__":
# text = "これは、電話をかけるための私の日本語の例のテキストです。"
# e = JA_JP_Phonemizer()
# print(e.supported_languages())
# print(e.version())
# print(e.language)
# print(e.name())
# print(e.is_available())
# print("`" + e.phonemize(text) + "`")

View File

@@ -0,0 +1,55 @@
from typing import Dict, List
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
class MultiPhonemizer:
"""🐸TTS multi-phonemizer that operates phonemizers for multiple langugages
Args:
custom_lang_to_phonemizer (Dict):
Custom phonemizer mapping if you want to change the defaults. In the format of
`{"lang_code", "phonemizer_name"}`. When it is None, `DEF_LANG_TO_PHONEMIZER` is used. Defaults to `{}`.
TODO: find a way to pass custom kwargs to the phonemizers
"""
lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER
language = "multi-lingual"
def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: # pylint: disable=dangerous-default-value
self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer)
self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name)
@staticmethod
def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict:
lang_to_phonemizer = {}
for k, v in lang_to_phonemizer_name.items():
phonemizer = get_phonemizer_by_name(v, language=k)
lang_to_phonemizer[k] = phonemizer
return lang_to_phonemizer
@staticmethod
def name():
return "multi-phonemizer"
def phonemize(self, text, language, separator="|"):
return self.lang_to_phonemizer[language].phonemize(text, separator)
def supported_languages(self) -> List:
return list(self.lang_to_phonemizer_name.keys())
# if __name__ == "__main__":
# texts = {
# "tr": "Merhaba, bu Türkçe bit örnek!",
# "en-us": "Hello, this is English example!",
# "de": "Hallo, das ist ein Deutches Beipiel!",
# "zh-cn": "这是中国的例子",
# }
# phonemes = {}
# ph = MultiPhonemizer()
# for lang, text in texts.items():
# phoneme = ph.phonemize(text, lang)
# phonemes[lang] = phoneme
# print(phonemes)

View File

@@ -0,0 +1,62 @@
from typing import Dict
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】"
class ZH_CN_Phonemizer(BasePhonemizer):
"""🐸TTS Zh-Cn phonemizer using functions in `TTS.tts.utils.text.chinese_mandarin.phonemizer`
Args:
punctuations (str):
Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`.
keep_puncs (bool):
If True, keep the punctuations after phonemization. Defaults to False.
Example ::
"这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| || |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |`
TODO: someone with Mandarin knowledge should check this implementation
"""
language = "zh-cn"
def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument
super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)
@staticmethod
def name():
return "zh_cn_phonemizer"
@staticmethod
def phonemize_zh_cn(text: str, separator: str = "|") -> str:
ph = chinese_text_to_phonemes(text, separator)
return ph
def _phonemize(self, text, separator):
return self.phonemize_zh_cn(text, separator)
@staticmethod
def supported_languages() -> Dict:
return {"zh-cn": "Japanese (Japan)"}
def version(self) -> str:
return "0.0.1"
def is_available(self) -> bool:
return True
# if __name__ == "__main__":
# text = "这是,样本中文。"
# e = ZH_CN_Phonemizer()
# print(e.supported_languages())
# print(e.version())
# print(e.language)
# print(e.name())
# print(e.is_available())
# print("`" + e.phonemize(text) + "`")

View File

@@ -0,0 +1,172 @@
import collections
import re
from enum import Enum
import six
_DEF_PUNCS = ';:,.!?¡¿—…"«»“”'
_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])
class PuncPosition(Enum):
"""Enum for the punctuations positions"""
BEGIN = 0
END = 1
MIDDLE = 2
ALONE = 3
class Punctuation:
"""Handle punctuations in text.
Just strip punctuations from text or strip and restore them later.
Args:
puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
Example:
>>> punc = Punctuation()
>>> punc.strip("This is. example !")
'This is example'
>>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
>>> ' '.join(text_striped)
'This is example'
>>> text_restored = punc.restore(text_striped, punc_map)
>>> text_restored[0]
'This is. example !'
"""
def __init__(self, puncs: str = _DEF_PUNCS):
self.puncs = puncs
@staticmethod
def default_puncs():
"""Return default set of punctuations."""
return _DEF_PUNCS
@property
def puncs(self):
return self._puncs
@puncs.setter
def puncs(self, value):
if not isinstance(value, six.string_types):
raise ValueError("[!] Punctuations must be of type str.")
self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder
self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
def strip(self, text):
"""Remove all the punctuations by replacing with `space`.
Args:
text (str): The text to be processed.
Example::
"This is. example !" -> "This is example "
"""
return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()
def strip_to_restore(self, text):
"""Remove punctuations from text to restore them later.
Args:
text (str): The text to be processed.
Examples ::
"This is. example !" -> [["This is", "example"], [".", "!"]]
"""
text, puncs = self._strip_to_restore(text)
return text, puncs
def _strip_to_restore(self, text):
"""Auxiliary method for Punctuation.preserve()"""
matches = list(re.finditer(self.puncs_regular_exp, text))
if not matches:
return [text], []
# the text is only punctuations
if len(matches) == 1 and matches[0].group() == text:
return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
# build a punctuation map to be used later to restore punctuations
puncs = []
for match in matches:
position = PuncPosition.MIDDLE
if match == matches[0] and text.startswith(match.group()):
position = PuncPosition.BEGIN
elif match == matches[-1] and text.endswith(match.group()):
position = PuncPosition.END
puncs.append(_PUNC_IDX(match.group(), position))
# convert str text to a List[str], each item is separated by a punctuation
splitted_text = []
for idx, punc in enumerate(puncs):
split = text.split(punc.punc)
prefix, suffix = split[0], punc.punc.join(split[1:])
splitted_text.append(prefix)
# if the text does not end with a punctuation, add it to the last item
if idx == len(puncs) - 1 and len(suffix) > 0:
splitted_text.append(suffix)
text = suffix
return splitted_text, puncs
@classmethod
def restore(cls, text, puncs):
"""Restore punctuation in a text.
Args:
text (str): The text to be processed.
puncs (List[str]): The list of punctuations map to be used for restoring.
Examples ::
['This is', 'example'], ['.', '!'] -> "This is. example!"
"""
return cls._restore(text, puncs, 0)
@classmethod
def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
"""Auxiliary method for Punctuation.restore()"""
if not puncs:
return text
# nothing has been phonemized, return the puncs alone
if not text:
return ["".join(m.mark for m in puncs)]
current = puncs[0]
if current.position == PuncPosition.BEGIN:
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
if current.position == PuncPosition.END:
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
if current.position == PuncPosition.ALONE:
return [current.punc] + cls._restore(text, puncs[1:], num + 1)
# POSITION == MIDDLE
if len(text) == 1: # pragma: nocover
# a corner case where the final part of an intermediate
# mark (I) has not been phonemized
return cls._restore([text[0] + current.punc], puncs[1:], num)
return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
# if __name__ == "__main__":
# punc = Punctuation()
# text = "This is. This is, example!"
# print(punc.strip(text))
# split_text, puncs = punc.strip_to_restore(text)
# print(split_text, " ---- ", puncs)
# restored_text = punc.restore(split_text, puncs)
# print(restored_text)
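A short sketch of the punctuation map built by `strip_to_restore`; the attribute names come from the `_PUNC_IDX` namedtuple above:

from TTS.tts.utils.text.punctuation import Punctuation

punc = Punctuation()
chunks, punc_map = punc.strip_to_restore("Be a voice, not an echo!")
print(chunks)                            # ['Be a voice', 'not an echo']
for p in punc_map:
    print(repr(p.punc), p.position)      # ', ' is MIDDLE, '!' is END
print(punc.restore(chunks, punc_map)[0]) # round-trips back to the original sentence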

View File

@@ -1,75 +0,0 @@
# -*- coding: utf-8 -*-
"""
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
"""
def make_symbols(
characters,
phonemes=None,
punctuations="!'(),-.:;? ",
pad="_",
eos="~",
bos="^",
unique=True,
): # pylint: disable=redefined-outer-name
"""Function to create symbols and phonemes
TODO: create phonemes_to_id and symbols_to_id dicts here."""
_symbols = list(characters)
_symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols
_symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols
_symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols
_phonemes = None
if phonemes is not None:
_phonemes_sorted = (
sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
) # this is to keep previous models compatible.
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ["@" + s for s in _phonemes_sorted]
# Export all symbols:
_phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
# _symbols += _arpabet
return _symbols, _phonemes
_pad = "_"
_eos = "~"
_bos = "^"
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? "
_punctuations = "!'(),-.:;? "
# Phonemes definition (All IPA characters)
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "ˈˌːˑ"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)
# Generate ALIEN language
# from random import shuffle
# shuffle(phonemes)
def parse_symbols():
return {
"pad": _pad,
"eos": _eos,
"bos": _bos,
"characters": _characters,
"punctuations": _punctuations,
"phonemes": _phonemes,
}
if __name__ == "__main__":
print(" > TTS symbols {}".format(len(symbols)))
print(symbols)
print(" > TTS phonemes {}".format(len(phonemes)))
print("".join(sorted(phonemes)))

View File

@@ -0,0 +1,205 @@
from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.utils.generic_utils import get_import_path, import_class
class TTSTokenizer:
"""🐸TTS tokenizer to convert input characters to token IDs and back.
Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later.
Args:
use_phonemes (bool):
Whether to use phonemes instead of characters. Defaults to False.
characters (Characters):
A Characters object to use for character-to-ID and ID-to-character mappings.
text_cleaner (callable):
A function to pre-process the text before tokenization and phonemization. Defaults to None.
phonemizer (Phonemizer):
A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None.
Example:
>>> from TTS.tts.utils.text.tokenizer import TTSTokenizer
>>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes())
>>> text = "Hello world!"
>>> ids = tokenizer.text_to_ids(text)
>>> text_hat = tokenizer.ids_to_text(ids)
>>> assert text == text_hat
"""
def __init__(
self,
use_phonemes=False,
text_cleaner: Callable = None,
characters: "BaseCharacters" = None,
phonemizer: Union["Phonemizer", Dict] = None,
add_blank: bool = False,
use_eos_bos=False,
):
self.text_cleaner = text_cleaner
self.use_phonemes = use_phonemes
self.add_blank = add_blank
self.use_eos_bos = use_eos_bos
self.characters = characters
self.not_found_characters = []
self.phonemizer = phonemizer
@property
def characters(self):
return self._characters
@characters.setter
def characters(self, new_characters):
self._characters = new_characters
self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None
self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None
def encode(self, text: str) -> List[int]:
"""Encodes a string of text as a sequence of IDs."""
token_ids = []
for char in text:
try:
idx = self.characters.char_to_id(char)
token_ids.append(idx)
except KeyError:
# discard but store not found characters
if char not in self.not_found_characters:
self.not_found_characters.append(char)
print(text)
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
return token_ids
def decode(self, token_ids: List[int]) -> str:
"""Decodes a sequence of IDs to a string of text."""
text = ""
for token_id in token_ids:
text += self.characters.id_to_char(token_id)
return text
def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument
"""Converts a string of text to a sequence of token IDs.
Args:
text(str):
The text to convert to token IDs.
language(str):
The language code of the text. Defaults to None.
TODO:
- Add support for language-specific processing.
1. Text normalization
2. Phonemization (if use_phonemes is True)
3. Add blank char between characters
4. Add BOS and EOS characters
5. Text to token IDs
"""
# TODO: text cleaner should pick the right routine based on the language
if self.text_cleaner is not None:
text = self.text_cleaner(text)
if self.use_phonemes:
text = self.phonemizer.phonemize(text, separator="")
if self.add_blank:
text = self.intersperse_blank_char(text, True)
if self.use_eos_bos:
text = self.pad_with_bos_eos(text)
return self.encode(text)
def ids_to_text(self, id_sequence: List[int]) -> str:
"""Converts a sequence of token IDs to a string of text."""
return self.decode(id_sequence)
def pad_with_bos_eos(self, char_sequence: List[str]):
"""Pads a sequence with the special BOS and EOS characters."""
return [self.characters.bos] + list(char_sequence) + [self.characters.eos]
def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
"""Intersperses the blank character between characters in a sequence.
Use the ```blank``` character if defined else use the ```pad``` character.
"""
char_to_use = self.characters.blank if use_blank_char else self.characters.pad
result = [char_to_use] * (len(char_sequence) * 2 + 1)
result[1::2] = char_sequence
return result
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > add_blank: {self.add_blank}")
print(f"{indent}| > use_eos_bos: {self.use_eos_bos}")
print(f"{indent}| > use_phonemes: {self.use_phonemes}")
if self.use_phonemes:
print(f"{indent}| > phonemizer:")
self.phonemizer.print_logs(level + 1)
if len(self.not_found_characters) > 0:
print(f"{indent}| > {len(self.not_found_characters)} not found characters:")
for char in self.not_found_characters:
print(f"{indent}| > {char}")
@staticmethod
def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):
"""Init Tokenizer object from config
Args:
config (Coqpit): Coqpit model config.
characters (BaseCharacters): Defines the model character set. If not set, use the default options based on
the config values. Defaults to None.
"""
# init cleaners
text_cleaner = None
if isinstance(config.text_cleaner, (str, list)):
text_cleaner = getattr(cleaners, config.text_cleaner)
# init characters
if characters is None:
# set characters based on defined characters class
if config.characters and config.characters.characters_class:
CharactersClass = import_class(config.characters.characters_class)
characters, new_config = CharactersClass.init_from_config(config)
# set characters based on config
else:
if config.use_phonemes:
# init phoneme set
characters, new_config = IPAPhonemes().init_from_config(config)
else:
# init character set
characters, new_config = Graphemes().init_from_config(config)
else:
characters, new_config = characters.init_from_config(config)
# set characters class
new_config.characters.characters_class = get_import_path(characters)
# init phonemizer
phonemizer = None
if config.use_phonemes:
phonemizer_kwargs = {"language": config.phoneme_language}
if "phonemizer" in config and config.phonemizer:
phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
else:
try:
phonemizer = get_phonemizer_by_name(
DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
)
except KeyError as e:
raise ValueError(
f"""No phonemizer found for language {config.phoneme_language}.
You may need to install a third party library for this language."""
) from e
return (
TTSTokenizer(
config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars
),
new_config,
)
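A sketch of the grapheme-only path with the optional blank token, using only classes shown in this diff; the blank symbol itself comes from the default `Graphemes` settings:

from TTS.tts.utils.text.characters import Graphemes
from TTS.tts.utils.text.tokenizer import TTSTokenizer

tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes(), add_blank=True)
ids = tokenizer.text_to_ids("Hello world!")
print(len(ids))                    # 2 * len("Hello world!") + 1 because of the interspersed blank token
print(tokenizer.ids_to_text(ids))  # decodes back with the blank character in between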

View File

@@ -4,8 +4,6 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme
 matplotlib.use("Agg")
@@ -89,12 +87,46 @@ def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False)
     return fig
+def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
+    """Plot pitch curves on top of the input characters.
+    Args:
+        pitch (np.array): Pitch values.
+        chars (str): Characters to place to the x-axis.
+    Shapes:
+        pitch: :math:`(T,)`
+    """
+    old_fig_size = plt.rcParams["figure.figsize"]
+    if fig_size is not None:
+        plt.rcParams["figure.figsize"] = fig_size
+    fig, ax = plt.subplots()
+    x = np.array(range(len(chars)))
+    my_xticks = chars
+    plt.xticks(x, my_xticks)
+    ax.set_xlabel("characters")
+    ax.set_ylabel("freq")
+    ax2 = ax.twinx()
+    ax2.plot(pitch, linewidth=5.0, color="red")
+    ax2.set_ylabel("F0")
+    plt.rcParams["figure.figsize"] = old_fig_size
+    if not output_fig:
+        plt.close()
+    return fig
 def visualize(
     alignment,
     postnet_output,
     text,
     hop_length,
     CONFIG,
+    tokenizer,
     stop_tokens=None,
     decoder_output=None,
     output_path=None,
@@ -117,14 +149,8 @@ def visualize(
     plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
     # compute phoneme representation and back
     if CONFIG.use_phonemes:
-        seq = phoneme_to_sequence(
-            text,
-            [CONFIG.text_cleaner],
-            CONFIG.phoneme_language,
-            CONFIG.enable_eos_bos_chars,
-            tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
-        )
-        text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None)
+        seq = tokenizer.text_to_ids(text)
+        text = tokenizer.ids_to_text(seq)
     print(text)
     plt.yticks(range(len(text)), list(text))
     plt.colorbar()
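A sketch of the new `plot_avg_pitch` helper with dummy values, assuming this file is `TTS/tts/utils/visual.py`:

import numpy as np

from TTS.tts.utils.visual import plot_avg_pitch

chars = list("hello")
pitch = np.random.uniform(80.0, 220.0, size=len(chars))  # one averaged F0 value per character
fig = plot_avg_pitch(pitch, chars, fig_size=(12, 4), output_fig=True)
fig.savefig("avg_pitch.png")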

View File

@@ -142,10 +142,10 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         )
         M = o[:, :, :, 0]
         P = o[:, :, :, 1]
-        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
         if self.power is not None:
-            S = S ** self.power
+            S = S**self.power
         if self.use_mel:
             S = torch.matmul(self.mel_basis.to(x), S)
@@ -239,6 +239,12 @@ class AudioProcessor(object):
         mel_fmax (int, optional):
             maximum filter frequency for computing melspectrograms. Defaults to None.
+        pitch_fmin (int, optional):
+            minimum filter frequency for computing pitch. Defaults to None.
+        pitch_fmax (int, optional):
+            maximum filter frequency for computing pitch. Defaults to None.
         spec_gain (int, optional):
             gain applied when converting amplitude to DB. Defaults to 20.
@@ -300,6 +306,8 @@ class AudioProcessor(object):
         max_norm=None,
         mel_fmin=None,
         mel_fmax=None,
+        pitch_fmax=None,
+        pitch_fmin=None,
         spec_gain=20,
         stft_pad_mode="reflect",
         clip_norm=True,
@@ -333,6 +341,8 @@ class AudioProcessor(object):
         self.symmetric_norm = symmetric_norm
         self.mel_fmin = mel_fmin or 0
         self.mel_fmax = mel_fmax
+        self.pitch_fmin = pitch_fmin
+        self.pitch_fmax = pitch_fmax
         self.spec_gain = float(spec_gain)
         self.stft_pad_mode = stft_pad_mode
         self.max_norm = 1.0 if max_norm is None else float(max_norm)
@@ -379,6 +389,12 @@ class AudioProcessor(object):
             self.clip_norm = None
             self.symmetric_norm = None
+    @staticmethod
+    def init_from_config(config: "Coqpit", verbose=True):
+        if "audio" in config:
+            return AudioProcessor(verbose=verbose, **config.audio)
+        return AudioProcessor(verbose=verbose, **config)
     ### setting up the parameters ###
     def _build_mel_basis(
         self,
@@ -634,8 +650,8 @@ class AudioProcessor(object):
         S = self._db_to_amp(S)
         # Reconstruct phase
         if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
-        return self._griffin_lim(S ** self.power)
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
+        return self._griffin_lim(S**self.power)
     def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
         """Convert a melspectrogram to a waveform using Griffin-Lim vocoder."""
@@ -643,8 +659,8 @@ class AudioProcessor(object):
         S = self._db_to_amp(D)
         S = self._mel_to_linear(S)  # Convert back to linear
         if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
-        return self._griffin_lim(S ** self.power)
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
+        return self._griffin_lim(S**self.power)
     def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
         """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@@ -720,11 +736,12 @@ class AudioProcessor(object):
            >>> WAV_FILE = filename = librosa.util.example_audio_file()
            >>> from TTS.config import BaseAudioConfig
            >>> from TTS.utils.audio import AudioProcessor
-           >>> conf = BaseAudioConfig(mel_fmax=8000)
+           >>> conf = BaseAudioConfig(pitch_fmax=8000)
            >>> ap = AudioProcessor(**conf)
            >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
            >>> pitch = ap.compute_f0(wav)
        """
+        assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`."
         # align F0 length to the spectrogram length
         if len(x) % self.hop_length == 0:
             x = np.pad(x, (0, self.hop_length // 2), mode="reflect")
@@ -732,7 +749,7 @@ class AudioProcessor(object):
         f0, t = pw.dio(
             x.astype(np.double),
             fs=self.sample_rate,
-            f0_ceil=self.mel_fmax,
+            f0_ceil=self.pitch_fmax,
             frame_period=1000 * self.hop_length / self.sample_rate,
         )
         f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
@@ -781,7 +798,7 @@ class AudioProcessor(object):
     @staticmethod
     def _rms_norm(wav, db_level=-27):
         r = 10 ** (db_level / 20)
-        a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
+        a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
         return wav * a
     def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
@@ -853,7 +870,7 @@ class AudioProcessor(object):
     @staticmethod
     def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
-        mu = 2 ** qc - 1
+        mu = 2**qc - 1
         # wav_abs = np.minimum(np.abs(wav), 1.0)
         signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
         # Quantize signal to the specified number of levels.
@@ -865,13 +882,13 @@ class AudioProcessor(object):
     @staticmethod
     def mulaw_decode(wav, qc):
         """Recovers waveform from quantized values."""
-        mu = 2 ** qc - 1
+        mu = 2**qc - 1
         x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
         return x
     @staticmethod
     def encode_16bits(x):
-        return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
+        return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
     @staticmethod
     def quantize(x: np.ndarray, bits: int) -> np.ndarray:
@@ -884,12 +901,12 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Quantized waveform.
         """
-        return (x + 1.0) * (2 ** bits - 1) / 2
+        return (x + 1.0) * (2**bits - 1) / 2
     @staticmethod
     def dequantize(x, bits):
         """Dequantize a waveform from the given number of bits."""
-        return 2 * x / (2 ** bits - 1) - 1
+        return 2 * x / (2**bits - 1) - 1
 def _log(x, base):
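Mirroring the updated doctest above: `pitch_fmax` must now be set before `compute_f0` is called, and the new `AudioProcessor.init_from_config(config)` is an alternative entry point when a full model config is at hand. The ceiling value below is an arbitrary example:

import librosa

from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor

conf = BaseAudioConfig(pitch_fmax=640)  # without it, the new assert in compute_f0 fires
ap = AudioProcessor(**conf)
wav = ap.load_wav(librosa.util.example_audio_file(), sr=22050)[: 5 * 22050]
pitch = ap.compute_f0(wav)
print(pitch.shape)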

View File

@@ -128,7 +128,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") ->
     while True:
         # Read by chunk to avoid filling memory
-        chunk = file_obj.read(1024 ** 2)
+        chunk = file_obj.read(1024**2)
         if not chunk:
             break
         hash_func.update(chunk)
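The loop above hashes the file in 1 MiB chunks; a standalone sketch of the same pattern (hypothetical helper, not part of the library):

import hashlib

def sha256_of(path: str, chunk_size: int = 1024**2) -> str:
    # hash large files without loading them fully into memory
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()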

Some files were not shown because too many files have changed in this diff.