Merge pull request #1288 from coqui-ai/dev

v0.6.0
This commit is contained in:
Eren Gölge 2022-03-07 14:05:30 +01:00 committed by GitHub
commit 209ee40c88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
208 changed files with 6477 additions and 6934 deletions

View File

@ -1,16 +0,0 @@
#!/bin/bash
yes | apt-get install sox
yes | apt-get install ffmpeg
yes | apt-get install tmux
yes | apt-get install zsh
sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"
pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl
sudo sh install.sh
# pip install pytorch==1.7.0+cu100
# python3 setup.py develop
# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
# python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360
# python3 distribute.py --config_path config.json
while true; do sleep 1000000; done

46
.github/workflows/data_tests.yml vendored Normal file
View File

@ -0,0 +1,46 @@
name: data-tests
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
experimental: [false]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make data_tests

46
.github/workflows/inference_tests.yml vendored Normal file
View File

@ -0,0 +1,46 @@
name: inference_tests
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
experimental: [false]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make inference_tests

48
.github/workflows/text_tests.yml vendored Normal file
View File

@ -0,0 +1,48 @@
name: tts-tests
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
experimental: [false]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
sudo apt-get install espeak
sudo apt-get install espeak-ng
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make test_text

View File

@ -35,6 +35,8 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
sudo apt-get install espeak
sudo apt-get install espeak-ng
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel

View File

@ -35,6 +35,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y git make gcc
sudo apt-get install espeak espeak-ng
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel

3
.gitignore vendored
View File

@ -164,4 +164,5 @@ internal/*
*_pitch.npy
*_phoneme.npy
wandb
depot/*
depot/*
coqui_recipes/*

View File

@ -168,7 +168,8 @@ disable=missing-docstring,
exception-escape,
comprehension-escape,
duplicate-code,
not-callable
not-callable,
import-outside-toplevel
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option

View File

@ -26,6 +26,15 @@ test_aux: ## run aux tests.
test_zoo: ## run zoo tests.
nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id
inference_tests: ## run inference tests.
nosetests tests.inference_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.inference_tests --nologcapture --with-id
data_tests: ## run data tests.
nosetests tests.data_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.data_tests --nologcapture --with-id
test_text: ## run text tests.
nosetests tests.text_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.text_tests --nologcapture --with-id
test_failed: ## only run tests failed the last time.
nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed
@ -41,7 +50,6 @@ system-deps: ## install linux system deps
dev-deps: ## install development deps
pip install -r requirements.dev.txt
pip install -r requirements.tf.txt
doc-deps: ## install docs dependencies
pip install -r docs/requirements.txt

View File

@ -61,8 +61,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Detailed training logs on the terminal and Tensorboard.
- Support for Multi-speaker TTS.
- Efficient, flexible, lightweight but feature complete `Trainer API`.
- Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference.
- Released and read-to-use models.
- Released and ready-to-use models.
- Tools to curate Text2Speech datasets under```dataset_analysis```.
- Utilities to use and test your models.
- Modular (but not too much) code base enabling easy implementation of new ideas.
@ -113,17 +112,11 @@ If you are only interested in [synthesizing speech](https://tts.readthedocs.io/e
pip install TTS
```
By default, this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra.
```bash
pip install TTS[tf]
```
If you plan to code or train models, clone 🐸TTS and install it locally.
```bash
git clone https://github.com/coqui-ai/TTS
pip install -e .[all,dev,notebooks,tf] # Select the relevant extras
pip install -e .[all,dev,notebooks] # Select the relevant extras
```
If you are on Ubuntu (Debian), you can also run following commands for installation.
@ -204,12 +197,10 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
|- train*.py (train your target model.)
|- distribute.py (train your TTS model using Multiple GPUs.)
|- compute_statistics.py (compute dataset statistics for normalization.)
|- convert*.py (convert target torch model to TF.)
|- ...
|- tts/ (text to speech models)
|- layers/ (model layer definitions)
|- models/ (model definitions)
|- tf/ (Tensorflow 2 utilities and model implementations)
|- utils/ (model specific utilities.)
|- speaker_encoder/ (Speaker Encoder models.)
|- (same)

View File

@ -4,7 +4,7 @@
"multi-dataset":{
"your_tts":{
"description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
"default_vocoder": null,
"commit": "e9a1953e",
"license": "CC BY-NC-ND 4.0",
@ -33,7 +33,7 @@
},
"tacotron2-DDC_ph": {
"description": "Tacotron2 with Double Decoder Consistency with phonemes.",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--tacotron2-DDC_ph.zip",
"default_vocoder": "vocoder_models/en/ljspeech/univnet",
"commit": "3900448",
"author": "Eren Gölge @erogol",
@ -71,7 +71,7 @@
},
"vits": {
"description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--vits.zip",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--vits.zip",
"default_vocoder": null,
"commit": "3900448",
"author": "Eren Gölge @erogol",
@ -89,18 +89,9 @@
}
},
"vctk": {
"sc-glow-tts": {
"description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip",
"default_vocoder": "vocoder_models/en/vctk/hifigan_v2",
"commit": "b531fa69",
"author": "Edresson Casanova",
"license": "",
"contact": ""
},
"vits": {
"description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--vctk--vits.zip",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--vits.zip",
"default_vocoder": null,
"commit": "3900448",
"author": "Eren @erogol",
@ -109,7 +100,7 @@
},
"fast_pitch":{
"description": "FastPitch model trained on VCTK dataseset.",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--en--vctk--fast_pitch.zip",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--fast_pitch.zip",
"default_vocoder": null,
"commit": "bdab788d",
"author": "Eren @erogol",
@ -156,7 +147,7 @@
"uk":{
"mai": {
"glow-tts": {
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--uk--mailabs--glow-tts.zip",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--uk--mai--glow-tts.zip",
"author":"@robinhad",
"commit": "bdab788d",
"license": "MIT",
@ -168,7 +159,7 @@
"zh-CN": {
"baker": {
"tacotron2-DDC-GST": {
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
"commit": "unknown",
"author": "@kirianguiller",
"default_vocoder": null
@ -206,6 +197,52 @@
"commit": "401fbd89"
}
}
},
"tr":{
"common-voice": {
"glow-tts":{
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--tr--common-voice--glow-tts.zip",
"default_vocoder": "vocoder_models/tr/common-voice/hifigan",
"license": "MIT",
"description": "Turkish GlowTTS model using an unknown speaker from the Common-Voice dataset.",
"author": "Fatih Akademi",
"commit": null
}
}
},
"it": {
"mai_female": {
"glow-tts":{
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--glow-tts.zip",
"default_vocoder": null,
"description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
"author": "@nicolalandro",
"commit": null
},
"vits":{
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--vits.zip",
"default_vocoder": null,
"description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
"author": "@nicolalandro",
"commit": null
}
},
"mai_male": {
"glow-tts":{
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--glow-tts.zip",
"default_vocoder": null,
"description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
"author": "@nicolalandro",
"commit": null
},
"vits":{
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--vits.zip",
"default_vocoder": null,
"description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
"author": "@nicolalandro",
"commit": null
}
}
}
},
"vocoder_models": {
@ -324,6 +361,17 @@
"contact": ""
}
}
},
"tr":{
"common-voice": {
"hifigan":{
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/vocoder_models--tr--common-voice--hifigan.zip",
"description": "HifiGAN model using an unknown speaker from the Common-Voice dataset.",
"author": "Fatih Akademi",
"license": "MIT",
"commit": null
}
}
}
}
}

View File

@ -1 +1 @@
0.5.0
0.6.0

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_checkpoint

View File

@ -29,6 +29,9 @@ parser.add_argument(
help="Path to dataset config file.",
)
parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
parser.add_argument(
"--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None
)
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
@ -40,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli
wav_files = meta_data_train + meta_data_eval
speaker_manager = SpeakerManager(
encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
encoder_model_path=args.model_path,
encoder_config_path=args.config_path,
d_vectors_file_path=args.old_file,
use_cuda=args.use_cuda,
)
# compute speaker embeddings
@ -52,11 +58,15 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
else:
speaker_name = None
# extract the embedding
embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
wav_file_name = os.path.basename(wav_file)
if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
# get the embedding from the old file
embedd = speaker_manager.get_d_vector_by_clip(wav_file_name)
else:
# extract the embedding
embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
# create speaker_mapping if target dataset is defined
wav_file_name = os.path.basename(wav_file)
speaker_mapping[wav_file_name] = {}
speaker_mapping[wav_file_name]["name"] = speaker_name
speaker_mapping[wav_file_name]["embedding"] = embedd

View File

@ -51,7 +51,7 @@ def main():
N = 0
for item in tqdm(dataset_items):
# compute features
wav = ap.load_wav(item if isinstance(item, str) else item[1])
wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
linear = ap.spectrogram(wav)
mel = ap.melspectrogram(wav)
@ -59,13 +59,13 @@ def main():
N += mel.shape[1]
mel_sum += mel.sum(1)
linear_sum += linear.sum(1)
mel_square_sum += (mel ** 2).sum(axis=1)
linear_square_sum += (linear ** 2).sum(axis=1)
mel_square_sum += (mel**2).sum(axis=1)
linear_square_sum += (linear**2).sum(axis=1)
mel_mean = mel_sum / N
mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2)
mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2)
linear_mean = linear_sum / N
linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2)
linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2)
output_file_path = args.out_path
stats = {}

View File

@ -1,25 +0,0 @@
# Convert Tensorflow Tacotron2 model to TF-Lite binary
import argparse
from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.generic_utils import setup_generator
from TTS.vocoder.tf.utils.io import load_checkpoint
from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite
parser = argparse.ArgumentParser()
parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
args = parser.parse_args()
# Set constants
CONFIG = load_config(args.config_path)
# load the model
model = setup_generator(CONFIG)
model.build_inference()
model = load_checkpoint(model, args.tf_model)
# create tflite model
tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path)

View File

@ -1,105 +0,0 @@
import argparse
import os
from difflib import SequenceMatcher
import numpy as np
import tensorflow as tf
import torch
from TTS.utils.io import load_config, load_fsspec
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf,
convert_tf_name,
transfer_weights_torch_to_tf,
)
from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import save_checkpoint
from TTS.vocoder.utils.generic_utils import setup_generator
# prevent GPU use
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# define args
parser = argparse.ArgumentParser()
parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
args = parser.parse_args()
# load model config
config_path = args.config_path
c = load_config(config_path)
num_speakers = 0
# init torch model
model = setup_generator(c)
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()
state_dict = model.state_dict()
# init tf model
model_tf = setup_tf_generator(c)
common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
# get tf_model graph by passing an input
# B x D x T
dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
mel_pred = model_tf(dummy_input, training=False)
# get tf variables
tf_vars = model_tf.weights
# match variable names with fuzzy logic
torch_var_names = list(state_dict.keys())
tf_var_names = [we.name for we in model_tf.weights]
var_map = []
for tf_name in tf_var_names:
# skip re-mapped layer names
if tf_name in [name[0] for name in var_map]:
continue
tf_name_edited = convert_tf_name(tf_name)
ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
max_idx = np.argmax(ratios)
matching_name = torch_var_names[max_idx]
del torch_var_names[max_idx]
var_map.append((tf_name, matching_name))
# pass weights
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
# Compare TF and TORCH models
# check embedding outputs
model.eval()
dummy_input_torch = torch.ones((1, 80, 10))
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1])
dummy_input_tf = tf.expand_dims(dummy_input_tf, 2)
out_torch = model.layers[0](dummy_input_torch)
out_tf = model_tf.model_layers[0](dummy_input_tf)
out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
assert compare_torch_tf(out_torch, out_tf_) < 1e-5
for i in range(1, len(model.layers)):
print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}")
out_torch = model.layers[i](out_torch)
out_tf = model_tf.model_layers[i](out_tf)
out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
diff = compare_torch_tf(out_torch, out_tf_)
assert diff < 1e-5, diff
torch.manual_seed(0)
dummy_input_torch = torch.rand((1, 80, 100))
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
model.inference_padding = 0
model_tf.inference_padding = 0
output_torch = model.inference(dummy_input_torch)
output_tf = model_tf(dummy_input_tf, training=False)
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf)
# save tf model
save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path)
print(" > Model conversion is successfully completed :).")

View File

@ -1,30 +0,0 @@
# Convert Tensorflow Tacotron2 model to TF-Lite binary
import argparse
from TTS.tts.tf.utils.generic_utils import setup_model
from TTS.tts.tf.utils.io import load_checkpoint
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config
parser = argparse.ArgumentParser()
parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
args = parser.parse_args()
# Set constants
CONFIG = load_config(args.config_path)
# load the model
c = CONFIG
num_speakers = 0
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
model.build_inference()
model = load_checkpoint(model, args.tf_model)
model.decoder.set_max_decoder_steps(1000)
# create tflite model
tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)

View File

@ -1,187 +0,0 @@
import argparse
import os
import sys
from difflib import SequenceMatcher
from pprint import pprint
import numpy as np
import tensorflow as tf
import torch
from TTS.tts.models import setup_model
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config, load_fsspec
sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
parser = argparse.ArgumentParser()
parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
args = parser.parse_args()
# load model config
config_path = args.config_path
c = load_config(config_path)
num_speakers = 0
# init torch model
model = setup_model(c)
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
# init tf model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model_tf = Tacotron2(
num_chars=num_chars,
num_speakers=num_speakers,
r=model.decoder.r,
out_channels=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
)
# set initial layer mapping - these are not captured by the below heuristic approach
# TODO: set layer names so that we can remove these manual matching
common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
var_map = [
("embedding/embeddings:0", "embedding.weight"),
("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"),
("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"),
("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"),
("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"),
("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")),
(
"encoder/lstm/backward_lstm/lstm_cell_2/bias:0",
("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"),
),
("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"),
("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"),
("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"),
]
# %%
# get tf_model graph
model_tf.build_inference()
# get tf variables
tf_vars = model_tf.weights
# match variable names with fuzzy logic
torch_var_names = list(state_dict.keys())
tf_var_names = [we.name for we in model_tf.weights]
for tf_name in tf_var_names:
# skip re-mapped layer names
if tf_name in [name[0] for name in var_map]:
continue
tf_name_edited = convert_tf_name(tf_name)
ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
max_idx = np.argmax(ratios)
matching_name = torch_var_names[max_idx]
del torch_var_names[max_idx]
var_map.append((tf_name, matching_name))
pprint(var_map)
pprint(torch_var_names)
# pass weights
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
# Compare TF and TORCH models
# %%
# check embedding outputs
model.eval()
input_ids = torch.randint(0, 24, (1, 128)).long()
o_t = model.embedding(input_ids)
o_tf = model_tf.embedding(input_ids.detach().numpy())
assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum()
# compare encoder outputs
oo_en = model.encoder.inference(o_t.transpose(1, 2))
ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
assert compare_torch_tf(oo_en, ooo_en) < 1e-5
# pylint: disable=redefined-builtin
# compare decoder.attention_rnn
inp = torch.rand([1, 768])
inp_tf = inp.numpy()
model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access
output, cell_state = model.decoder.attention_rnn(inp)
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False)
assert compare_torch_tf(output, output_tf).mean() < 1e-5
query = output
inputs = torch.rand([1, 128, 512])
query_tf = query.detach().numpy()
inputs_tf = inputs.numpy()
# compare decoder.attention
model.decoder.attention.init_states(inputs)
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs)
context = model.decoder.attention(query, inputs, processes_inputs, None)
attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states)
context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False)
assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
assert compare_torch_tf(context, context_tf) < 1e-5
# compare decoder.decoder_rnn
input = torch.rand([1, 1536])
input_tf = input.numpy()
model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access
output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False)
assert abs(input - input_tf).mean() < 1e-5
assert compare_torch_tf(output, output_tf).mean() < 1e-5
# compare decoder.linear_projection
input = torch.rand([1, 1536])
input_tf = input.numpy()
output = model.decoder.linear_projection(input)
output_tf = model_tf.decoder.linear_projection(input_tf, training=False)
assert compare_torch_tf(output, output_tf) < 1e-5
# compare decoder outputs
model.decoder.max_decoder_steps = 100
model_tf.decoder.set_max_decoder_steps(100)
output, align, stop = model.decoder.inference(oo_en)
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4
# compare the whole model output
outputs_torch = model.inference(input_ids)
outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean())
assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5
assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
# %%
# save tf model
save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path)
print(" > Model conversion is successfully completed :).")

View File

@ -7,15 +7,14 @@ import subprocess
import time
import torch
from TTS.trainer import TrainingArgs
from trainer import TrainerArgs
def main():
"""
Call train.py as a new process and pass command arguments
"""
parser = TrainingArgs().init_argparse(arg_prefix="")
parser = TrainerArgs().init_argparse(arg_prefix="")
parser.add_argument("--script", type=str, help="Target training script to distibute.")
args, unargs = parser.parse_known_args()

View File

@ -13,6 +13,7 @@ from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
@ -20,21 +21,20 @@ use_cuda = torch.cuda.is_available()
def setup_loader(ap, r, verbose=False):
tokenizer, _ = TTSTokenizer.init_from_config(c)
dataset = TTSDataset(
r,
c.text_cleaner,
outputs_per_step=r,
compute_linear_spec=False,
meta_data=meta_data,
samples=meta_data,
tokenizer=tokenizer,
ap=ap,
characters=c.characters if "characters" in c.keys() else None,
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
batch_group_size=0,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
min_text_len=c.min_text_len,
max_text_len=c.max_text_len,
min_audio_len=c.min_audio_len,
max_audio_len=c.max_audio_len,
phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
precompute_num_workers=0,
use_noise_augment=False,
verbose=verbose,
speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
@ -44,7 +44,7 @@ def setup_loader(ap, r, verbose=False):
if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(c.num_loader_workers)
dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
dataset.preprocess_samples()
loader = DataLoader(
dataset,
@ -75,8 +75,8 @@ def set_filename(wav_path, out_path):
def format_data(data):
# setup input data
text_input = data["text"]
text_lengths = data["text_lengths"]
text_input = data["token_id"]
text_lengths = data["token_id_lengths"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
item_idx = data["item_idxs"]
@ -138,7 +138,7 @@ def inference(
aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
)
model_output = outputs["model_outputs"]
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
model_output = model_output.detach().cpu().numpy()
elif "tacotron" in model_name:
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
@ -229,7 +229,9 @@ def main(args): # pylint: disable=redefined-outer-name
ap = AudioProcessor(**c.audio)
# load data instances
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval)
meta_data_train, meta_data_eval = load_tts_samples(
c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
)
# use eval and training partitions
meta_data = meta_data_train + meta_data_eval

View File

@ -23,7 +23,10 @@ def main():
c = load_config(args.config_path)
# load all datasets
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
train_items, eval_items = load_tts_samples(
c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
)
items = train_items + eval_items
texts = "".join(item[0] for item in items)

View File

@ -7,14 +7,15 @@ from tqdm.contrib.concurrent import process_map
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text import text2phone
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
phonemizer = Gruut(language="en-us")
def compute_phonemes(item):
try:
text = item[0]
language = item[-1]
ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
ph = phonemizer.phonemize(text).split("|")
except:
return []
return list(set(ph))
@ -39,10 +40,17 @@ def main():
c = load_config(args.config_path)
# load all datasets
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
train_items, eval_items = load_tts_samples(
c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
)
items = train_items + eval_items
print("Num items:", len(items))
is_lang_def = all(item["language"] for item in items)
if not c.phoneme_language or not is_lang_def:
raise ValueError("Phoneme language must be defined in config.")
phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
phones = []
for ph in phonemes:

View File

@ -26,6 +26,7 @@ if __name__ == "__main__":
--input_dir /root/LJSpeech-1.1/
--output_sr 22050
--output_dir /root/resampled_LJSpeech-1.1/
--file_ext wav
--n_jobs 24
""",
formatter_class=RawTextHelpFormatter,
@ -55,6 +56,14 @@ if __name__ == "__main__":
help="Path of the destination folder. If not defined, the operation is done in place",
)
parser.add_argument(
"--file_ext",
type=str,
default="wav",
required=False,
help="Extension of the audio files to resample",
)
parser.add_argument(
"--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores"
)
@ -67,7 +76,7 @@ if __name__ == "__main__":
args.input_dir = args.output_dir
print("Resampling the audio files...")
audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True)
audio_files = glob.glob(os.path.join(args.input_dir, f"**/*.{args.file_ext}"), recursive=True)
print(f"Found {len(audio_files)} files...")
audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr]))
with Pool(processes=args.n_jobs) as p:

View File

@ -8,6 +8,7 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.torch import NoamLR
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
@ -19,7 +20,7 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, check_update
from TTS.utils.training import check_update
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

View File

@ -1,19 +1,22 @@
import os
import torch
from dataclasses import dataclass, field
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from trainer import Trainer, TrainerArgs
from TTS.config import load_config, register_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@dataclass
class TrainTTSArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
def main():
"""Run `tts` model training directly by a `config.json` file."""
# init trainer args
train_args = TrainingArgs()
train_args = TrainTTSArgs()
parser = train_args.init_argparse(arg_prefix="")
# override trainer args from comman-line args
@ -41,45 +44,15 @@ def main():
config = register_config(config_base.model)()
# load training samples
train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True)
# setup audio processor
ap = AudioProcessor(**config.audio)
# init speaker manager
if check_config_and_model_args(config, "use_speaker_embedding", True):
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
if hasattr(config, "model_args"):
config.model_args.num_speakers = speaker_manager.num_speakers
else:
config.num_speakers = speaker_manager.num_speakers
elif check_config_and_model_args(config, "use_d_vector_file", True):
if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
speaker_manager = SpeakerManager(
d_vectors_file_path=config.model_args.d_vector_file,
encoder_model_path=config.model_args.speaker_encoder_model_path,
encoder_config_path=config.model_args.speaker_encoder_config_path,
use_cuda=torch.cuda.is_available(),
)
else:
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
config.num_speakers = speaker_manager.num_speakers
if hasattr(config, "model_args"):
config.model_args.num_speakers = speaker_manager.num_speakers
else:
speaker_manager = None
if check_config_and_model_args(config, "use_language_embedding", True):
language_manager = LanguageManager(config=config)
if hasattr(config, "model_args"):
config.model_args.num_languages = language_manager.num_languages
else:
config.num_languages = language_manager.num_languages
else:
language_manager = None
train_samples, eval_samples = load_tts_samples(
config.datasets,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the model from config
model = setup_model(config, speaker_manager, language_manager)
model = setup_model(config, train_samples + eval_samples)
# init the trainer and 🚀
trainer = Trainer(
@ -89,7 +62,6 @@ def main():
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
parse_command_line_args=False,
)
trainer.fit()

View File

@ -1,16 +1,23 @@
import os
from dataclasses import dataclass, field
from trainer import Trainer, TrainerArgs
from TTS.config import load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model
@dataclass
class TrainVocoderArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
def main():
"""Run `tts` model training directly by a `config.json` file."""
# init trainer args
train_args = TrainingArgs()
train_args = TrainVocoderArgs()
parser = train_args.init_argparse(arg_prefix="")
# override trainer args from comman-line args

View File

@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass
from typing import List
from coqpit import Coqpit, check_argument
from trainer import TrainerConfig
@dataclass
@ -57,6 +58,12 @@ class BaseAudioConfig(Coqpit):
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
pitch_fmax (float, optional):
Maximum frequency of the F0 frames. Defaults to ```640```.
pitch_fmin (float, optional):
Minimum frequency of the F0 frames. Defaults to ```0```.
trim_db (int):
Silence threshold used for silence trimming. Defaults to 45.
@ -135,6 +142,9 @@ class BaseAudioConfig(Coqpit):
spec_gain: int = 20
do_amp_to_db_linear: bool = True
do_amp_to_db_mel: bool = True
# f0 params
pitch_fmax: float = 640.0
pitch_fmin: float = 0.0
# normalization params
signal_norm: bool = True
min_level_db: int = -100
@ -228,130 +238,24 @@ class BaseDatasetConfig(Coqpit):
@dataclass
class BaseTrainingConfig(Coqpit):
"""Base config to define the basic training parameters that are shared
among all the models.
class BaseTrainingConfig(TrainerConfig):
"""Base config to define the basic 🐸TTS training parameters that are shared
among all the models. It is based on ```Trainer.TrainingConfig```.
Args:
model (str):
Name of the model that is used in the training.
run_name (str):
Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.
run_description (str):
Short description of the experiment.
epochs (int):
Number training epochs. Defaults to 10000.
batch_size (int):
Training batch size.
eval_batch_size (int):
Validation batch size.
mixed_precision (bool):
Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
it may also cause numerical unstability in some cases.
scheduler_after_epoch (bool):
If true, run the scheduler step after each epoch else run it after each model step.
run_eval (bool):
Enable / Disable evaluation (validation) run. Defaults to True.
test_delay_epochs (int):
Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
results, hence waiting for a couple of epochs might save some time.
print_eval (bool):
Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at
the end of the evaluation. Default to ```False```.
print_step (int):
Number of steps required to print the next training log.
log_dashboard (str): "tensorboard" or "wandb"
Set the experiment tracking tool
plot_step (int):
Number of steps required to log training on Tensorboard.
model_param_stats (bool):
Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
Defaults to ```False```.
project_name (str):
Name of the project. Defaults to config.model
wandb_entity (str):
Name of W&B entity/team. Enables collaboration across a team or org.
log_model_step (int):
Number of steps required to log a checkpoint as W&B artifact
save_step (int):ipt
Number of steps required to save the next checkpoint.
checkpoint (bool):
Enable / Disable checkpointing.
keep_all_best (bool):
Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
to ```False```.
keep_after (int):
Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
to 10000.
num_loader_workers (int):
Number of workers for training time dataloader.
num_eval_loader_workers (int):
Number of workers for evaluation time dataloader.
output_path (str):
Path for training output folder, either a local file path or other
URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
S3 (s3://) paths. The nonexist part of the given path is created
automatically. All training artefacts are saved there.
"""
model: str = None
run_name: str = "coqui_tts"
run_description: str = ""
# training params
epochs: int = 10000
batch_size: int = None
eval_batch_size: int = None
mixed_precision: bool = False
scheduler_after_epoch: bool = False
# eval params
run_eval: bool = True
test_delay_epochs: int = 0
print_eval: bool = False
# logging
dashboard_logger: str = "tensorboard"
print_step: int = 25
plot_step: int = 100
model_param_stats: bool = False
project_name: str = None
log_model_step: int = None
wandb_entity: str = None
# checkpointing
save_step: int = 10000
checkpoint: bool = True
keep_all_best: bool = False
keep_after: int = 10000
# dataloading
num_loader_workers: int = 0
num_eval_loader_workers: int = 0
use_noise_augment: bool = False
use_language_weighted_sampler: bool = False
# paths
output_path: str = None
# distributed
distributed_backend: str = "nccl"
distributed_url: str = "tcp://localhost:54321"

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
@ -9,28 +8,21 @@ from torch import nn
# pylint: skip-file
class BaseModel(nn.Module, ABC):
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
class BaseTrainerModel(ABC, nn.Module):
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
Notes on input/output tensor shapes:
Any input or output tensor of the model must be shaped as
@staticmethod
@abstractmethod
def init_from_config(config: Coqpit):
"""Init the model from given config.
- 3D tensors `batch x time x channels`
- 2D tensors `batch x channels`
- 1D tensors `batch x 1`
"""
def __init__(self, config: Coqpit):
super().__init__()
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Set model arguments from the config. Override this."""
pass
Override this depending on your model.
"""
...
@abstractmethod
def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
"""Forward pass for the model mainly used in training.
"""Forward ... for the model mainly used in training.
You can be flexible here and use different number of arguments and argument names since it is intended to be
used by `train_step()` without exposing it out of the model.
@ -48,7 +40,7 @@ class BaseModel(nn.Module, ABC):
@abstractmethod
def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
"""Forward pass for inference.
"""Forward ... for inference.
We don't use `*kwargs` since it is problematic with the TorchScript API.
@ -63,9 +55,25 @@ class BaseModel(nn.Module, ABC):
...
return outputs_dict
def format_batch(self, batch: Dict) -> Dict:
"""Format batch returned by the data loader before sending it to the model.
If not implemented, model uses the batch as is.
Can be used for data augmentation, feature ectraction, etc.
"""
return batch
def format_batch_on_device(self, batch: Dict) -> Dict:
"""Format batch on device before sending it to the model.
If not implemented, model uses the batch as is.
Can be used for data augmentation, feature ectraction, etc.
"""
return batch
@abstractmethod
def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
"""Perform a single training step. Run the model forward ... and compute losses.
Args:
batch (Dict): Input tensors.
@ -93,11 +101,11 @@ class BaseModel(nn.Module, ABC):
Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform.
"""
pass
...
@abstractmethod
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
"""Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
"""Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can
call `train_step()` with no changes.
Args:
@ -114,36 +122,49 @@ class BaseModel(nn.Module, ABC):
def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
"""The same as `train_log()`"""
pass
...
@abstractmethod
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
"""Load a checkpoint and get ready for training or inference.
Args:
config (Coqpit): Model configuration.
checkpoint_path (str): Path to the model checkpoint file.
eval (bool, optional): If true, init model for inference else for training. Defaults to False.
strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
"""
...
def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
"""Setup an return optimizer or optimizers."""
pass
@staticmethod
@abstractmethod
def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel":
"""Init the model from given config.
def get_lr(self) -> Union[float, List[float]]:
"""Return learning rate(s).
Returns:
Union[float, List[float]]: Model's initial learning rates.
Override this depending on your model.
"""
pass
...
def get_scheduler(self, optimizer: torch.optim.Optimizer):
pass
@abstractmethod
def get_data_loader(
self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
):
...
def get_criterion(self):
pass
# def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
# """Setup an return optimizer or optimizers."""
# ...
def format_batch(self):
pass
# def get_lr(self) -> Union[float, List[float]]:
# """Return learning rate(s).
# Returns:
# Union[float, List[float]]: Model's initial learning rates.
# """
# ...
# def get_scheduler(self, optimizer: torch.optim.Optimizer):
# ...
# def get_criterion(self):
# ...

View File

@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path:
if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE3: set custome model paths
# CASE3: set custom model paths
if args.model_path is not None:
model_path = args.model_path
config_path = args.config_path
@ -170,9 +170,9 @@ def tts():
text = request.args.get("text")
speaker_idx = request.args.get("speaker_id", "")
style_wav = request.args.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text))
print(" > Speaker Idx: {}".format(speaker_idx))
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)

View File

@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset):
mel = self.ap.melspectrogram(wav).astype("float32")
# sample seq_len
assert text.size > 0, self.items[idx][1]
assert wav.size > 0, self.items[idx][1]
assert text.size > 0, self.items[idx]["audio_file"]
assert wav.size > 0, self.items[idx]["audio_file"]
sample = {
"mel": mel,
"item_idx": self.items[idx][1],
"item_idx": self.items[idx]["audio_file"],
"speaker_name": speaker_name,
}
return sample
@ -91,8 +91,8 @@ class SpeakerEncoderDataset(Dataset):
def __parse_items(self):
self.speaker_to_utters = {}
for i in self.items:
path_ = i[1]
speaker_ = i[2]
path_ = i["audio_file"]
speaker_ = i["speaker_name"]
if speaker_ in self.speaker_to_utters.keys():
self.speaker_to_utters[speaker_].append(path_)
else:

View File

@ -229,7 +229,7 @@ class ResNetSpeakerEncoder(nn.Module):
x = torch.sum(x * w, dim=2)
elif self.encoder_type == "ASP":
mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1)

View File

@ -113,7 +113,7 @@ class AugmentWAV(object):
def additive_noise(self, noise_type, audio):
clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
noise_list = random.sample(
self.noise_list[noise_type],
@ -135,7 +135,7 @@ class AugmentWAV(object):
self.additive_noise_config[noise_type]["min_snr_in_db"],
self.additive_noise_config[noise_type]["max_num_noises"],
)
noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
if noises_wav is None:
@ -154,7 +154,7 @@ class AugmentWAV(object):
rir_file = random.choice(self.rir_files)
rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
rir = rir / np.sqrt(np.sum(rir ** 2))
rir = rir / np.sqrt(np.sum(rir**2))
return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
def apply_one(self, audio):

View File

@ -1,19 +1,24 @@
import os
from dataclasses import dataclass, field
from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config
from TTS.trainer import TrainingArgs
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
from TTS.utils.logging import init_dashboard_logger
from TTS.utils.logging.console_logger import ConsoleLogger
from TTS.utils.trainer_utils import get_last_checkpoint
@dataclass
class TrainArgs(TrainerArgs):
config_path: str = field(default=None, metadata={"help": "Path to the config file."})
def getarguments():
train_config = TrainingArgs()
train_config = TrainArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
@ -75,13 +80,13 @@ def process_args(args, config=None):
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
dashboard_logger = init_dashboard_logger(config)
dashboard_logger = logger_factory(config, experiment_path)
c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger
def init_arguments():
train_config = TrainingArgs()
train_config = TrainArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser

File diff suppressed because it is too large Load Diff

View File

@ -89,11 +89,11 @@ class FastPitchConfig(BaseTTSConfig):
pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
binary_loss_alpha (float):
binary_align_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
binary_loss_warmup_epochs (float):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int):
Minimum input sequence length to be used at training.
@ -129,12 +129,12 @@ class FastPitchConfig(BaseTTSConfig):
duration_loss_type: str = "mse"
use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
pitch_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
pitch_loss_alpha: float = 0.1
dur_loss_alpha: float = 0.1
binary_align_loss_alpha: float = 0.1
binary_loss_warmup_epochs: int = 150
# overrides
min_seq_len: int = 13

View File

@ -93,8 +93,8 @@ class FastSpeechConfig(BaseTTSConfig):
binary_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
binary_loss_warmup_epochs (float):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int):
Minimum input sequence length to be used at training.
@ -135,7 +135,7 @@ class FastSpeechConfig(BaseTTSConfig):
pitch_loss_alpha: float = 0.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
binary_loss_warmup_epochs: int = 150
# overrides
min_seq_len: int = 13

View File

@ -153,6 +153,7 @@ class GlowTTSConfig(BaseTTSConfig):
# multi-speaker settings
use_speaker_embedding: bool = False
speakers_file: str = None
use_d_vector_file: bool = False
d_vector_file: str = False

View File

@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field
from typing import List
from typing import Dict, List
from coqpit import Coqpit, check_argument
@ -50,9 +50,16 @@ class GSTConfig(Coqpit):
@dataclass
class CharactersConfig(Coqpit):
"""Defines character or phoneme set used by the model
"""Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses.
Args:
characters_class (str):
Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on
the configuration. Defaults to None.
vocab_dict (dict):
Defines the vocabulary dictionary used to encode the characters. Defaults to None.
pad (str):
characters in place of empty padding. Defaults to None.
@ -62,6 +69,9 @@ class CharactersConfig(Coqpit):
bos (str):
characters showing the beginning of a sentence. Defaults to None.
blank (str):
Optional character used between characters by some models for better prosody. Defaults to `_blank`.
characters (str):
character set used by the model. Characters not in this list are ignored when converting input text to
a list of sequence IDs. Defaults to None.
@ -70,32 +80,32 @@ class CharactersConfig(Coqpit):
characters considered as punctuation as parsing the input sentence. Defaults to None.
phonemes (str):
characters considered as parsing phonemes. Defaults to None.
characters considered as parsing phonemes. This is only for backwards compat. Use `characters` for new
models. Defaults to None.
unique (bool):
is_unique (bool):
remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old
models trained with character lists with duplicates.
models trained with character lists with duplicates. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
"""
characters_class: str = None
# using BaseVocabulary
vocab_dict: Dict = None
# using on BaseCharacters
pad: str = None
eos: str = None
bos: str = None
blank: str = None
characters: str = None
punctuations: str = None
phonemes: str = None
unique: bool = True # for backwards compatibility of models trained with char sets with duplicates
def check_values(
self,
):
"""Check config fields"""
c = asdict(self)
check_argument("pad", c, prerequest="characters", restricted=True)
check_argument("eos", c, prerequest="characters", restricted=True)
check_argument("bos", c, prerequest="characters", restricted=True)
check_argument("characters", c, prerequest="characters", restricted=True)
check_argument("phonemes", c, restricted=True)
check_argument("punctuations", c, prerequest="characters", restricted=True)
is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates
is_sorted: bool = True
@dataclass
@ -110,8 +120,13 @@ class BaseTTSConfig(BaseTrainingConfig):
use_phonemes (bool):
enable / disable phoneme use.
use_espeak_phonemes (bool):
enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`).
phonemizer (str):
Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`.
Defaults to None.
phoneme_language (str):
Language code for the phonemizer. You can check the list of supported languages by running
`python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None.
compute_input_seq_cache (bool):
enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of
@ -144,11 +159,19 @@ class BaseTTSConfig(BaseTrainingConfig):
sort_by_audio_len (bool):
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
min_seq_len (int):
Minimum sequence length to be used at training.
min_text_len (int):
Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.
max_seq_len (int):
Maximum sequence length to be used at training. Larger values result in more VRAM usage.
max_text_len (int):
Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").
min_audio_len (int):
Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.
max_audio_len (int):
Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
OOM error in training. Defaults to float("inf").
compute_f0 (int):
(Not in use yet).
@ -156,9 +179,16 @@ class BaseTTSConfig(BaseTrainingConfig):
compute_linear_spec (bool):
If True data loader computes and returns linear spectrograms alongside the other data.
precompute_num_workers (int):
Number of workers to precompute features. Defaults to 0.
use_noise_augment (bool):
Augment the input audio with random noise.
start_by_longest (bool):
If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
Defaults to False.
add_blank (bool):
Add blank characters between each other two characters. It improves performance for some models at expense
of slower run-time due to the longer input sequence.
@ -183,12 +213,19 @@ class BaseTTSConfig(BaseTrainingConfig):
test_sentences (List[str]):
List of sentences to be used at testing. Defaults to '[]'
eval_split_max_size (int):
Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
"""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
# phoneme settings
use_phonemes: bool = False
use_espeak_phonemes: bool = True
phonemizer: str = None
phoneme_language: str = None
compute_input_seq_cache: bool = False
text_cleaner: str = None
@ -197,17 +234,21 @@ class BaseTTSConfig(BaseTrainingConfig):
phoneme_cache_path: str = None
# vocabulary parameters
characters: CharactersConfig = None
add_blank: bool = False
# training params
batch_group_size: int = 0
loss_masking: bool = None
# dataloading
sort_by_audio_len: bool = False
min_seq_len: int = 1
max_seq_len: int = float("inf")
min_audio_len: int = 1
max_audio_len: int = float("inf")
min_text_len: int = 1
max_text_len: int = float("inf")
compute_f0: bool = False
compute_linear_spec: bool = False
precompute_num_workers: int = 0
use_noise_augment: bool = False
add_blank: bool = False
start_by_longest: bool = False
# dataset
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer
@ -218,3 +259,6 @@ class BaseTTSConfig(BaseTrainingConfig):
lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing
test_sentences: List[str] = field(default_factory=lambda: [])
# evaluation
eval_split_max_size: int = None
eval_split_size: float = 0.01

View File

@ -89,8 +89,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
binary_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
binary_loss_warmup_epochs (float):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int):
Minimum input sequence length to be used at training.
@ -150,7 +150,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
spec_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 0.3
binary_align_loss_start_step: int = 50000
binary_loss_warmup_epochs: int = 150
# overrides
min_seq_len: int = 13

View File

@ -83,6 +83,8 @@ class TacotronConfig(BaseTTSConfig):
ddc_r (int):
reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this
as a multiple of the `r` value. Defaults to 6.
speakers_file (str):
Path to the speaker mapping file for the Speaker Manager. Defaults to None.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False.
@ -176,6 +178,7 @@ class TacotronConfig(BaseTTSConfig):
ddc_r: int = 6
# multi-speaker settings
speakers_file: str = None
use_speaker_embedding: bool = False
speaker_embedding_dim: int = 512
use_d_vector_file: bool = False

View File

@ -17,7 +17,7 @@ class VitsConfig(BaseTTSConfig):
Model architecture arguments. Defaults to `VitsArgs()`.
grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`.
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
lr_gen (float):
Initial learning rate for the generator. Defaults to 0.0002.
@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
sort_by_audio_len (bool):
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
min_seq_len (int):
Minimum sequnce length to be considered for training. Defaults to `0`.
max_seq_len (int):
Maximum sequnce length to be considered for training. Defaults to `500000`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
@ -130,9 +121,6 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec: bool = True
# overrides
sort_by_audio_len: bool = True
min_seq_len: int = 0
max_seq_len: int = 500000
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

View File

@ -9,25 +9,48 @@ from TTS.tts.datasets.dataset import *
from TTS.tts.datasets.formatters import *
def split_dataset(items):
def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
Args:
items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
Args:
<<<<<<< HEAD
items (List[List]):
A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
eval_split_max_size (int):
Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
=======
items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
>>>>>>> Fix docstring
"""
speakers = [item[-1] for item in items]
speakers = [item["speaker_name"] for item in items]
is_multi_speaker = len(set(speakers)) > 1
eval_split_size = min(500, int(len(items) * 0.01))
assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
if eval_split_size > 1:
eval_split_size = int(eval_split_size)
else:
if eval_split_max_size:
eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size))
else:
eval_split_size = int(len(items) * eval_split_size)
assert (
eval_split_size > 0
), " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format(
1 / len(items)
)
np.random.seed(0)
np.random.shuffle(items)
if is_multi_speaker:
items_eval = []
speakers = [item[-1] for item in items]
speakers = [item["speaker_name"] for item in items]
speaker_counter = Counter(speakers)
while len(items_eval) < eval_split_size:
item_idx = np.random.randint(0, len(items))
speaker_to_be_removed = items[item_idx][-1]
speaker_to_be_removed = items[item_idx]["speaker_name"]
if speaker_counter[speaker_to_be_removed] > 1:
items_eval.append(items[item_idx])
speaker_counter[speaker_to_be_removed] -= 1
@ -37,7 +60,11 @@ def split_dataset(items):
def load_tts_samples(
datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None
datasets: Union[List[Dict], Dict],
eval_split=True,
formatter: Callable = None,
eval_split_max_size=None,
eval_split_size=0.01,
) -> Tuple[List[List], List[List]]:
"""Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
@ -52,9 +79,16 @@ def load_tts_samples(
formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It
must take the root_path and the meta_file name and return a list of samples in the format of
`[[audio_path, text, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as
`[[text, audio_path, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as
example. Defaults to None.
eval_split_max_size (int):
Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
Returns:
Tuple[List[List], List[List]: training and evaluation splits of the dataset.
"""
@ -75,28 +109,28 @@ def load_tts_samples(
formatter = _get_formatter_by_name(name)
# load train set
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
meta_data_train = [[*item, language] for item in meta_data_train]
meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set
if eval_split:
if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
meta_data_eval = [[*item, language] for item in meta_data_eval]
meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
meta_data_eval_all += meta_data_eval
meta_data_train_all += meta_data_train
# load attention masks for the duration predictor training
if dataset.meta_file_attn_mask:
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
for idx, ins in enumerate(meta_data_train_all):
attn_file = meta_data[ins[1]].strip()
meta_data_train_all[idx].append(attn_file)
attn_file = meta_data[ins["audio_file"]].strip()
meta_data_train_all[idx].update({"alignment_file": attn_file})
if meta_data_eval_all:
for idx, ins in enumerate(meta_data_eval_all):
attn_file = meta_data[ins[1]].strip()
meta_data_eval_all[idx].append(attn_file)
attn_file = meta_data[ins["audio_file"]].strip()
meta_data_eval_all[idx].update({"alignment_file": attn_file})
# set none for the next iter
formatter = None
return meta_data_train_all, meta_data_eval_all

View File

@ -1,8 +1,7 @@
import collections
import os
import random
from multiprocessing import Pool
from typing import Dict, List
from typing import Dict, List, Union
import numpy as np
import torch
@ -10,87 +9,99 @@ import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
from TTS.utils.audio import AudioProcessor
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy("file_system")
def _parse_sample(item):
language_name = None
attn_file = None
if len(item) == 5:
text, wav_file, speaker_name, language_name, attn_file = item
elif len(item) == 4:
text, wav_file, speaker_name, language_name = item
elif len(item) == 3:
text, wav_file, speaker_name = item
else:
raise ValueError(" [!] Dataset cannot parse the sample.")
return text, wav_file, speaker_name, language_name, attn_file
def noise_augment_audio(wav):
return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
class TTSDataset(Dataset):
def __init__(
self,
outputs_per_step: int,
text_cleaner: list,
compute_linear_spec: bool,
ap: AudioProcessor,
meta_data: List[List],
outputs_per_step: int = 1,
compute_linear_spec: bool = False,
ap: AudioProcessor = None,
samples: List[Dict] = None,
tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False,
f0_cache_path: str = None,
characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False,
return_wav: bool = False,
batch_group_size: int = 0,
min_seq_len: int = 0,
max_seq_len: int = float("inf"),
use_phonemes: bool = False,
min_text_len: int = 0,
max_text_len: int = float("inf"),
min_audio_len: int = 0,
max_audio_len: int = float("inf"),
phoneme_cache_path: str = None,
phoneme_language: str = "en-us",
enable_eos_bos: bool = False,
precompute_num_workers: int = 0,
speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None,
language_id_mapping: Dict = None,
use_noise_augment: bool = False,
start_by_longest: bool = False,
verbose: bool = False,
):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
If you need something different, you can inherit and override.
If you need something different, you can subclass and override.
Args:
outputs_per_step (int): Number of time frames predicted per step.
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
meta_data (list): List of dataset instances.
samples (list): List of dataset samples.
tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
use the given. Defaults to None.
compute_f0 (bool): compute f0 if True. Defaults to False.
f0_cache_path (str): Path to store f0 cache. Defaults to None.
characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
set of symbols need to pass it here. Defaults to `None`.
add_blank (bool): Add a special `blank` character after every other character. It helps some
models achieve better results. Defaults to false.
return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
batch. Set 0 to disable. Defaults to 0.
min_seq_len (int): Minimum input sequence length to be processed
by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
minimum input length due to its architecture. Defaults to 0.
min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
Defaults to 0.
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
It helps for controlling the VRAM usage against long input sequences. Especially models with
RNN layers are sensitive to input length. Defaults to `Inf`.
max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
Defaults to float("inf").
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
Defaults to 0.
max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
this value if you encounter an OOM error in training. Defaults to float("inf").
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
separate file. Defaults to None.
phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults
to False.
precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
embedding layer. Defaults to None.
@ -99,285 +110,254 @@ class TTSDataset(Dataset):
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
verbose (bool): Print diagnostic information. Defaults to false.
"""
super().__init__()
self.batch_group_size = batch_group_size
self.items = meta_data
self._samples = samples
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
self.f0_cache_path = f0_cache_path
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.min_audio_len = min_audio_len
self.max_audio_len = max_audio_len
self.min_text_len = min_text_len
self.max_text_len = max_text_len
self.ap = ap
self.characters = characters
self.custom_symbols = custom_symbols
self.add_blank = add_blank
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language
self.enable_eos_bos = enable_eos_bos
self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping
self.language_id_mapping = language_id_mapping
self.use_noise_augment = use_noise_augment
self.start_by_longest = start_by_longest
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
self.pitch_computed = False
self.tokenizer = tokenizer
if self.tokenizer.use_phonemes:
self.phoneme_dataset = PhonemeDataset(
self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
)
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
if compute_f0:
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
self.f0_dataset = F0Dataset(
self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
)
if self.verbose:
print("\n > DataLoader initialization")
print(" | > Use phonemes: {}".format(self.use_phonemes))
if use_phonemes:
print(" | > phoneme language: {}".format(phoneme_language))
print(" | > Number of instances : {}".format(len(self.items)))
self.print_logs()
@property
def lengths(self):
lens = []
for item in self.samples:
_, wav_file, *_ = _parse_sample(item)
audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio
lens.append(audio_len)
return lens
@property
def samples(self):
return self._samples
@samples.setter
def samples(self, new_samples):
self._samples = new_samples
if hasattr(self, "f0_dataset"):
self.f0_dataset.samples = new_samples
if hasattr(self, "phoneme_dataset"):
self.phoneme_dataset.samples = new_samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.load_data(idx)
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> DataLoader initialization")
print(f"{indent}| > Tokenizer:")
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
def load_wav(self, filename):
audio = self.ap.load_wav(filename)
return audio
waveform = self.ap.load_wav(filename)
assert waveform.size > 0
return waveform
def get_phonemes(self, idx, text):
out_dict = self.phoneme_dataset[idx]
assert text == out_dict["text"], f"{text} != {out_dict['text']}"
assert len(out_dict["token_ids"]) > 0
return out_dict
def get_f0(self, idx):
out_dict = self.f0_dataset[idx]
item = self.samples[idx]
assert item["audio_file"] == out_dict["audio_file"]
return out_dict
@staticmethod
def load_np(filename):
data = np.load(filename).astype("float32")
return data
def get_attn_mask(attn_file):
return np.load(attn_file)
@staticmethod
def _generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
):
"""generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the
config option."""
phonemes = phoneme_to_sequence(
text,
[cleaners],
language=language,
enable_eos_bos=False,
custom_symbols=custom_symbols,
tp=characters,
add_blank=add_blank,
)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
return phonemes
@staticmethod
def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
# different names for normal phonemes and with blank chars.
file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
try:
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=characters)
phonemes = np.asarray(phonemes, dtype=np.int32)
return phonemes
def get_token_ids(self, idx, text):
if self.tokenizer.use_phonemes:
token_ids = self.get_phonemes(idx, text)["token_ids"]
else:
token_ids = self.tokenizer.text_to_ids(text)
return np.array(token_ids, dtype=np.int32)
def load_data(self, idx):
item = self.items[idx]
item = self.samples[idx]
if len(item) == 5:
text, wav_file, speaker_name, language_name, attn_file = item
else:
text, wav_file, speaker_name, language_name = item
attn = None
raw_text = text
raw_text = item["text"]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
# apply noise for augmentation
if self.use_noise_augment:
wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
wav = noise_augment_audio(wav)
if not self.input_seq_computed:
if self.use_phonemes:
text = self._load_or_generate_phoneme_sequence(
wav_file,
text,
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
language_name if language_name else self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
)
else:
text = np.asarray(
text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
# get token ids
token_ids = self.get_token_ids(idx, item["text"])
assert text.size > 0, self.items[idx][1]
assert wav.size > 0, self.items[idx][1]
# get pre-computed attention maps
attn = None
if "alignment_file" in item:
attn = self.get_attn_mask(item["alignment_file"])
if "attn_file" in locals():
attn = np.load(attn_file)
if len(text) > self.max_seq_len:
# return a different sample if the phonemized
# text is longer than the threshold
# TODO: find a better fix
# after phonemization the text length may change
# this is a shareful 🤭 hack to prevent longer phonemes
# TODO: find a better fix
if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
self.rescue_item_idx += 1
return self.load_data(self.rescue_item_idx)
pitch = None
# get f0 values
f0 = None
if self.compute_f0:
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
f0 = self.get_f0(idx)["f0"]
sample = {
"raw_text": raw_text,
"text": text,
"token_ids": token_ids,
"wav": wav,
"pitch": pitch,
"pitch": f0,
"attn": attn,
"item_idx": self.items[idx][1],
"speaker_name": speaker_name,
"language_name": language_name,
"wav_file_name": os.path.basename(wav_file),
"item_idx": item["audio_file"],
"speaker_name": item["speaker_name"],
"language_name": item["language"],
"wav_file_name": os.path.basename(item["audio_file"]),
}
return sample
@staticmethod
def _phoneme_worker(args):
item = args[0]
func_args = args[1]
text, wav_file, *_ = item
func_args[3] = (
item[3] if item[3] else func_args[3]
) # override phoneme language if specified by the dataset formatter
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
return phonemes
def _compute_lengths(samples):
new_samples = []
for item in samples:
audio_length = os.path.getsize(item["audio_file"]) / 16 * 8 # assuming 16bit audio
text_lenght = len(item["text"])
item["audio_length"] = audio_length
item["text_length"] = text_lenght
new_samples += [item]
return new_samples
def compute_input_seq(self, num_workers=0):
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not self.use_phonemes:
if self.verbose:
print(" | > Computing input sequences ...")
for idx, item in enumerate(tqdm.tqdm(self.items)):
text, *_ = item
sequence = np.asarray(
text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
self.items[idx][0] = sequence
else:
func_args = [
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
]
if self.verbose:
print(" | > Computing phonemes ...")
if num_workers == 0:
for idx, item in enumerate(tqdm.tqdm(self.items)):
phonemes = self._phoneme_worker([item, func_args])
self.items[idx][0] = phonemes
@staticmethod
def filter_by_length(lengths: List[int], min_len: int, max_len: int):
idxs = np.argsort(lengths) # ascending order
ignore_idx = []
keep_idx = []
for idx in idxs:
length = lengths[idx]
if length < min_len or length > max_len:
ignore_idx.append(idx)
else:
with Pool(num_workers) as p:
phonemes = list(
tqdm.tqdm(
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
total=len(self.items),
)
)
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
keep_idx.append(idx)
return ignore_idx, keep_idx
def sort_and_filter_items(self, by_audio_len=False):
@staticmethod
def sort_by_length(samples: List[List]):
audio_lengths = [s["audio_length"] for s in samples]
idxs = np.argsort(audio_lengths) # ascending order
return idxs
@staticmethod
def create_buckets(samples, batch_group_size: int):
assert batch_group_size > 0
for i in range(len(samples) // batch_group_size):
offset = i * batch_group_size
end_offset = offset + batch_group_size
temp_items = samples[offset:end_offset]
random.shuffle(temp_items)
samples[offset:end_offset] = temp_items
return samples
@staticmethod
def _select_samples_by_idx(idxs, samples):
samples_new = []
for idx in idxs:
samples_new.append(samples[idx])
return samples_new
def preprocess_samples(self):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
Args:
by_audio_len (bool): if True, sort by audio length else by text length.
"""
# compute the target sequence length
if by_audio_len:
lengths = []
for item in self.items:
lengths.append(os.path.getsize(item[1]) / 16 * 8) # assuming 16bit audio
lengths = np.array(lengths)
else:
lengths = np.array([len(ins[0]) for ins in self.items])
samples = self._compute_lengths(self.samples)
# sort items based on the sequence length in ascending order
text_lengths = [i["text_length"] for i in samples]
audio_lengths = [i["audio_length"] for i in samples]
text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
samples = self._select_samples_by_idx(keep_idx, samples)
sorted_idxs = self.sort_by_length(samples)
if self.start_by_longest:
longest_idxs = sorted_idxs[-1]
sorted_idxs[-1] = sorted_idxs[0]
sorted_idxs[0] = longest_idxs
samples = self._select_samples_by_idx(sorted_idxs, samples)
if len(samples) == 0:
raise RuntimeError(" [!] No samples left")
idxs = np.argsort(lengths)
new_items = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len or length > self.max_seq_len:
ignored.append(idx)
else:
new_items.append(self.items[idx])
# shuffle batch groups
# create batches with similar length items
# the larger the `batch_group_size`, the higher the length variety in a batch.
if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size
temp_items = new_items[offset:end_offset]
random.shuffle(temp_items)
new_items[offset:end_offset] = temp_items
self.items = new_items
samples = self.create_buckets(samples, self.batch_group_size)
# update items to the new sorted items
audio_lengths = [s["audio_length"] for s in samples]
text_lengths = [s["text_length"] for s in samples]
self.samples = samples
if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))
print(" | > Avg length sequence: {}".format(np.mean(lengths)))
print(
" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
self.max_seq_len, self.min_seq_len, len(ignored)
)
)
print(" | > Preprocessing samples")
print(" | > Max text length: {}".format(np.max(text_lengths)))
print(" | > Min text length: {}".format(np.min(text_lengths)))
print(" | > Avg text length: {}".format(np.mean(text_lengths)))
print(" | ")
print(" | > Max audio length: {}".format(np.max(audio_lengths)))
print(" | > Min audio length: {}".format(np.min(audio_lengths)))
print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
print(" | > Batch group size: {}.".format(self.batch_group_size))
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
return self.load_data(idx)
@staticmethod
def _sort_batch(batch, text_lengths):
"""Sort the batch by the input text length for RNN efficiency.
@ -402,10 +382,10 @@ class TTSDataset(Dataset):
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.abc.Mapping):
text_lengths = np.array([len(d["text"]) for d in batch])
token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])
# sort items with text input length for RNN efficiency
batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths)
# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
@ -447,7 +427,7 @@ class TTSDataset(Dataset):
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
# PAD sequences with longest instance in the batch
text = prepare_data(batch["text"]).astype(np.int32)
token_ids = prepare_data(batch["token_ids"]).astype(np.int32)
# PAD features with longest instance
mel = prepare_tensor(mel, self.outputs_per_step)
@ -456,12 +436,13 @@ class TTSDataset(Dataset):
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lengths = torch.LongTensor(text_lengths)
text = torch.LongTensor(text)
token_ids_lengths = torch.LongTensor(token_ids_lengths)
token_ids = torch.LongTensor(token_ids)
mel = torch.FloatTensor(mel).contiguous()
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
# speaker vectors
if d_vectors is not None:
d_vectors = torch.FloatTensor(d_vectors)
@ -472,14 +453,13 @@ class TTSDataset(Dataset):
language_ids = torch.LongTensor(language_ids)
# compute linear spectrogram
linear = None
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
linear = torch.FloatTensor(linear).contiguous()
else:
linear = None
# format waveforms
wav_padded = None
@ -495,8 +475,7 @@ class TTSDataset(Dataset):
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
# compute f0
# TODO: compare perf in collate_fn vs in load_data
# format F0
if self.compute_f0:
pitch = prepare_data(batch["pitch"])
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
@ -504,23 +483,22 @@ class TTSDataset(Dataset):
else:
pitch = None
# collate attention alignments
# format attention masks
attns = None
if batch["attn"][0] is not None:
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
for idx, attn in enumerate(attns):
pad2 = mel.shape[1] - attn.shape[1]
pad1 = text.shape[1] - attn.shape[0]
pad1 = token_ids.shape[1] - attn.shape[0]
assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
attn = np.pad(attn, [[0, pad1], [0, pad2]])
attns[idx] = attn
attns = prepare_tensor(attns, self.outputs_per_step)
attns = torch.FloatTensor(attns).unsqueeze(1)
else:
attns = None
# TODO: return dictionary
return {
"text": text,
"text_lengths": text_lengths,
"token_id": token_ids,
"token_id_lengths": token_ids_lengths,
"speaker_names": batch["speaker_name"],
"linear": linear,
"mel": mel,
@ -546,22 +524,185 @@ class TTSDataset(Dataset):
)
class PitchExtractor:
"""Pitch Extractor for computing F0 from wav files.
class PhonemeDataset(Dataset):
"""Phoneme Dataset for converting input text to phonemes and then token IDs
At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
loading latency. If `cache_path` is already present, it skips the pre-computation.
Args:
items (List[List]): Dataset samples.
verbose (bool): Whether to print the progress.
samples (Union[List[List], List[Dict]]):
List of samples. Each sample is a list or a dict.
tokenizer (TTSTokenizer):
Tokenizer to convert input text to phonemes.
cache_path (str):
Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
precompute_num_workers (int):
Number of workers used for pre-computing the phonemes. Defaults to 0.
"""
def __init__(
self,
items: List[List],
verbose=False,
samples: Union[List[Dict], List[List]],
tokenizer: "TTSTokenizer",
cache_path: str,
precompute_num_workers=0,
):
self.items = items
self.samples = samples
self.tokenizer = tokenizer
self.cache_path = cache_path
if cache_path is not None and not os.path.exists(cache_path):
os.makedirs(cache_path)
self.precompute(precompute_num_workers)
def __getitem__(self, index):
item = self.samples[index]
ids = self.compute_or_load(item["audio_file"], item["text"])
ph_hat = self.tokenizer.ids_to_text(ids)
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
def __len__(self):
return len(self.samples)
def compute_or_load(self, wav_file, text):
"""Compute phonemes for the given text.
If the phonemes are already cached, load them from cache.
"""
file_name = os.path.splitext(os.path.basename(wav_file))[0]
file_ext = "_phoneme.npy"
cache_path = os.path.join(self.cache_path, file_name + file_ext)
try:
ids = np.load(cache_path)
except FileNotFoundError:
ids = self.tokenizer.text_to_ids(text)
np.save(cache_path, ids)
return ids
def get_pad_id(self):
"""Get pad token ID for sequence padding"""
return self.tokenizer.pad_id
def precompute(self, num_workers=1):
"""Precompute phonemes for all samples.
We use pytorch dataloader because we are lazy.
"""
print("[*] Pre-computing phonemes...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
dataloder = torch.utils.data.DataLoader(
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
)
for _ in dataloder:
pbar.update(batch_size)
def collate_fn(self, batch):
ids = [item["token_ids"] for item in batch]
ids_lens = [item["token_ids_len"] for item in batch]
texts = [item["text"] for item in batch]
texts_hat = [item["ph_hat"] for item in batch]
ids_lens_max = max(ids_lens)
ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
for i, ids_len in enumerate(ids_lens):
ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> PhonemeDataset ")
print(f"{indent}| > Tokenizer:")
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
class F0Dataset:
"""F0 Dataset for computing F0 from wav files in CPU
Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
also computes the mean and std of F0 values if `normalize_f0` is True.
Args:
samples (Union[List[List], List[Dict]]):
List of samples. Each sample is a list or a dict.
ap (AudioProcessor):
AudioProcessor to compute F0 from wav files.
cache_path (str):
Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
Defaults to None.
precompute_num_workers (int):
Number of workers used for pre-computing the F0 values. Defaults to 0.
normalize_f0 (bool):
Whether to normalize F0 values by mean and std. Defaults to True.
"""
def __init__(
self,
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_f0=True,
):
self.samples = samples
self.ap = ap
self.verbose = verbose
self.cache_path = cache_path
self.normalize_f0 = normalize_f0
self.pad_id = 0.0
self.mean = None
self.std = None
if cache_path is not None and not os.path.exists(cache_path):
os.makedirs(cache_path)
self.precompute(precompute_num_workers)
if normalize_f0:
self.load_stats(cache_path)
def __getitem__(self, idx):
item = self.samples[idx]
f0 = self.compute_or_load(item["audio_file"])
if self.normalize_f0:
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
f0 = self.normalize(f0)
return {"audio_file": item["audio_file"], "f0": f0}
def __len__(self):
return len(self.samples)
def precompute(self, num_workers=0):
print("[*] Pre-computing F0s...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
# we do not normalize at preproessing
normalize_f0 = self.normalize_f0
self.normalize_f0 = False
dataloder = torch.utils.data.DataLoader(
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
)
computed_data = []
for batch in dataloder:
f0 = batch["f0"]
computed_data.append(f for f in f0)
pbar.update(batch_size)
self.normalize_f0 = normalize_f0
if self.normalize_f0:
computed_data = [tensor for batch in computed_data for tensor in batch] # flatten
pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
def get_pad_id(self):
return self.pad_id
@staticmethod
def create_pitch_file_path(wav_file, cache_path):
@ -583,70 +724,49 @@ class PitchExtractor:
mean, std = np.mean(nonzeros), np.std(nonzeros)
return mean, std
def normalize_pitch(self, pitch):
def load_stats(self, cache_path):
stats_path = os.path.join(cache_path, "pitch_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
def normalize(self, pitch):
zero_idxs = np.where(pitch == 0.0)[0]
pitch = pitch - self.mean
pitch = pitch / self.std
pitch[zero_idxs] = 0.0
return pitch
def denormalize_pitch(self, pitch):
def denormalize(self, pitch):
zero_idxs = np.where(pitch == 0.0)[0]
pitch *= self.std
pitch += self.mean
pitch[zero_idxs] = 0.0
return pitch
@staticmethod
def load_or_compute_pitch(ap, wav_file, cache_path):
def compute_or_load(self, wav_file):
"""
compute pitch and return a numpy array of pitch values
"""
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file):
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)
@staticmethod
def _pitch_worker(args):
item = args[0]
ap = args[1]
cache_path = args[2]
_, wav_file, *_ = item
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
if not os.path.exists(pitch_file):
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
return pitch
return None
def collate_fn(self, batch):
audio_file = [item["audio_file"] for item in batch]
f0s = [item["f0"] for item in batch]
f0_lens = [len(item["f0"]) for item in batch]
f0_lens_max = max(f0_lens)
f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
for i, f0_len in enumerate(f0_lens):
f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
def compute_pitch(self, ap, cache_path, num_workers=0):
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not os.path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
if self.verbose:
print(" | > Computing pitch features ...")
if num_workers == 0:
pitch_vecs = []
for _, item in enumerate(tqdm.tqdm(self.items)):
pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
else:
with Pool(num_workers) as p:
pitch_vecs = list(
tqdm.tqdm(
p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
total=len(self.items),
)
)
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
def load_pitch_stats(self, cache_path):
stats_path = os.path.join(cache_path, "pitch_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")

View File

@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
text = cols[1]
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
wav_file = cols[1].strip()
text = cols[0].strip()
wav_file = os.path.join(root_path, "wavs", wav_file)
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
text = cols[1].strip()
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
wav_file = os.path.join(root_path, folder_name, wav_file)
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
if os.path.isfile(wav_file):
text = cols[1].strip()
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
else:
# M-AI-Labs have some missing samples, so just print the warning
print("> File %s does not exist!" % (wav_file))
@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -129,11 +129,15 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, "r", encoding="utf-8") as ttf:
speaker_id = 0
for idx, line in enumerate(ttf):
# 2 samples per speaker to avoid eval split issues
if idx % 2 == 0:
speaker_id += 1
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append([text, wav_file, f"ljspeech-{idx}"])
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
return items
@ -150,7 +154,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
if not os.path.exists(wav_file):
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
continue
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -165,7 +169,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
text = cols[1]
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -179,7 +183,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
text = cols[1]
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -193,7 +197,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
utt_id = line.split()[1]
text = line[line.find('"') + 1 : line.rfind('"') - 1]
wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -213,7 +217,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
if speaker_name in ignored_speakers:
continue
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
items.append([text, wav_file, "MCV_" + speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
return items
@ -240,7 +244,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
items.append([text, wav_file, "LTTS_" + speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
for item in items:
assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
return items
@ -259,7 +263,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
skipped_files.append(wav_file)
continue
text = cols[1].strip()
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
return items
@ -281,12 +285,32 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append([text, wav_file, speaker_id])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
return items
def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None):
"""VCTK dataset v0.92.
URL:
https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip
This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```.
It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset.
mic1:
Audio recorded using an omni-directional microphone (DPA 4035).
Contains very low frequency noises.
This is the same audio released in previous versions of VCTK:
https://doi.org/10.7488/ds/1994
mic2:
Audio recorded using a small diaphragm condenser microphone with
very wide bandwidth (Sennheiser MKH 800).
Two speakers, p280 and p315 had technical issues of the audio
recordings using MKH 800.
"""
file_ext = "flac"
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
@ -298,26 +322,33 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([text, wav_file, "VCTK_" + speaker_id])
# p280 has no mic2 recordings
if speaker_id == "p280":
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}")
else:
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
if os.path.exists(wav_file):
items.append([text, wav_file, "VCTK_" + speaker_id])
else:
print(f" [!] wav files don't exist - {wav_file}")
return items
def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): # pylint: disable=unused-argument
def vctk_old(root_path, meta_files=None, wavs_path="wav48"):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
test_speakers = meta_files
items = []
txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for text_file in txt_files:
_, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
if isinstance(test_speakers, list): # if is list ignore this speakers ids
if speaker_id in test_speakers:
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([None, wav_file, "VCTK_" + speaker_id])
items.append([text, wav_file, "VCTK_old_" + speaker_id])
return items
@ -334,7 +365,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker in ignored_speakers:
continue
items.append([text, wav_file, "MLS_" + speaker])
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
return items
@ -404,7 +435,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
for line in ttf:
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
items.append([text, wav_path, speaker_name])
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
return items
@ -418,5 +449,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2].replace(" ", "")
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items

View File

@ -113,7 +113,7 @@ class ActNorm(nn.Module):
denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m ** 2)
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)

View File

@ -65,7 +65,7 @@ class WN(torch.nn.Module):
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
# intermediate layers
for i in range(num_layers):
dilation = dilation_rate ** i
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding

View File

@ -101,7 +101,7 @@ class Encoder(nn.Module):
self.encoder_type = encoder_type
# embedding layer
self.emb = nn.Embedding(num_chars, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
# init encoder module
if encoder_type.lower() == "rel_pos_transformer":
if use_prenet:

View File

@ -88,7 +88,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
# relative positional encoding layers
if rel_attn_window_size is not None:
n_heads_rel = 1 if heads_share else num_heads
rel_stddev = self.k_channels ** -0.5
rel_stddev = self.k_channels**-0.5
emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
)
@ -235,7 +235,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]

View File

@ -218,7 +218,7 @@ class GuidedAttentionLoss(torch.nn.Module):
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
@staticmethod
def _make_masks(ilens, olens):
@ -553,7 +553,6 @@ class VitsGeneratorLoss(nn.Module):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
@staticmethod
@ -588,13 +587,12 @@ class VitsGeneratorLoss(nn.Module):
@staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
return l
return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
def forward(
self,
waveform,
waveform_hat,
mel_slice,
mel_slice_hat,
z_p,
logs_q,
m_p,
@ -610,8 +608,8 @@ class VitsGeneratorLoss(nn.Module):
):
"""
Shapes:
- waveform : :math:`[B, 1, T]`
- waveform_hat: :math:`[B, 1, T]`
- mel_slice : :math:`[B, 1, T]`
- mel_slice_hat: :math:`[B, 1, T]`
- z_p: :math:`[B, C, T]`
- logs_q: :math:`[B, C, T]`
- m_p: :math:`[B, C, T]`
@ -624,23 +622,23 @@ class VitsGeneratorLoss(nn.Module):
loss = 0.0
return_dict = {}
z_mask = sequence_mask(z_len).float()
# compute mel spectrograms from the waveforms
mel = self.stft(waveform)
mel_hat = self.stft(waveform_hat)
# compute losses
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss_kl = (
self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
* self.kl_loss_alpha
)
loss_feat = (
self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
)
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
if use_speaker_encoder_as_loss:
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
loss += loss_se
loss = loss + loss_se
return_dict["loss_spk_encoder"] = loss_se
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl
@ -665,20 +663,24 @@ class VitsDiscriminatorLoss(nn.Module):
dr = dr.float()
dg = dg.float()
real_loss = torch.mean((1 - dr) ** 2)
fake_loss = torch.mean(dg ** 2)
fake_loss = torch.mean(dg**2)
loss += real_loss + fake_loss
real_losses.append(real_loss.item())
fake_losses.append(fake_loss.item())
return loss, real_losses, fake_losses
def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
loss_disc, loss_disc_real, _ = self.discriminator_loss(
scores_real=scores_disc_real, scores_fake=scores_disc_fake
)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss
for i, ldr in enumerate(loss_disc_real):
return_dict[f"loss_disc_real_{i}"] = ldr
return return_dict
@ -740,6 +742,7 @@ class ForwardTTSLoss(nn.Module):
alignment_logprob=None,
alignment_hard=None,
alignment_soft=None,
binary_loss_weight=None,
):
loss = 0
return_dict = {}
@ -772,7 +775,12 @@ class ForwardTTSLoss(nn.Module):
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
if binary_loss_weight:
return_dict["loss_binary_alignment"] = (
self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
)
else:
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss"] = loss
return return_dict

View File

@ -141,7 +141,7 @@ class MultiHeadAttention(nn.Module):
# score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k]
scores = scores / (self.key_dim ** 0.5)
scores = scores / (self.key_dim**0.5)
scores = F.softmax(scores, dim=3)
# out = score * V

View File

@ -6,7 +6,6 @@ from .attentions import init_attn
from .common_layers import Linear, Prenet
# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class ConvBNBlock(nn.Module):

View File

@ -57,7 +57,7 @@ class TextEncoder(nn.Module):
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
if language_emb_dim:
hidden_channels += language_emb_dim
@ -83,6 +83,7 @@ class TextEncoder(nn.Module):
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
assert x.shape[0] == x_lengths.shape[0]
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
# concat the lang emb in embedding chars
@ -90,7 +91,7 @@ class TextEncoder(nn.Module):
x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t]
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
@ -136,6 +137,9 @@ class ResidualCouplingBlock(nn.Module):
def forward(self, x, x_mask, g=None, reverse=False):
"""
Note:
Set `reverse` to True for inference.
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
@ -209,6 +213,9 @@ class ResidualCouplingBlocks(nn.Module):
def forward(self, x, x_mask, g=None, reverse=False):
"""
Note:
Set `reverse` to True for inference.
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`

View File

@ -33,7 +33,7 @@ class DilatedDepthSeparableConv(nn.Module):
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(num_layers):
dilation = kernel_size ** i
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
@ -264,7 +264,7 @@ class StochasticDurationPredictor(nn.Module):
# posterior encoder - neg log likelihood
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
nll_posterior_encoder = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q
)
z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
@ -279,7 +279,7 @@ class StochasticDurationPredictor(nn.Module):
z = torch.flip(z, [1])
# flow layers - neg log likelihood
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll_flow_layers + nll_posterior_encoder
flows = list(reversed(self.flows))

View File

@ -1,52 +1,14 @@
from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from typing import Dict, List, Union
from TTS.utils.generic_utils import find_module
def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
print(" > Using model: {}".format(config.model))
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
MyModel = find_module("TTS.tts.models", config.base_model.lower())
else:
MyModel = find_module("TTS.tts.models", config.model.lower())
# define set of characters used by the model
if config.characters is not None:
# set characters from config
if hasattr(MyModel, "make_symbols"):
symbols = MyModel.make_symbols(config)
else:
symbols, phonemes = make_symbols(**config.characters)
else:
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
if config.use_phonemes:
symbols = phonemes
# use default characters and assign them to config
config.characters = parse_symbols()
# consider special `blank` character if `add_blank` is set True
num_chars = len(symbols) + getattr(config, "add_blank", False)
config.num_chars = num_chars
# compatibility fix
if "model_params" in config:
config.model_params.num_chars = num_chars
if "model_args" in config:
config.model_args.num_chars = num_chars
if config.model.lower() in ["vits"]: # If model supports multiple languages
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
else:
model = MyModel(config, speaker_manager=speaker_manager)
model = MyModel.init_from_config(config, samples)
return model
# TODO; class registery
# def import_models(models_dir, namespace):
# for file in os.listdir(models_dir):
# path = os.path.join(models_dir, file)
# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
# model_name = file[: file.find(".py")] if file.endswith(".py") else file
# importlib.import_module(namespace + "." + model_name)
#
#
## automatically import any Python files in the models/ directory
# models_dir = os.path.dirname(__file__)
# import_models(models_dir, "TTS.tts.models")

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union
import torch
from coqpit import Coqpit
@ -12,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
@ -100,11 +102,16 @@ class AlignTTS(BaseTTS):
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: "AlignTTSConfig",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config)
super().__init__(config, ap, tokenizer, speaker_manager)
self.speaker_manager = speaker_manager
self.config = config
self.phase = -1
self.length_scale = (
float(config.model_args.length_scale)
@ -112,10 +119,6 @@ class AlignTTS(BaseTTS):
else config.model_args.length_scale
)
if not self.config.model_args.num_chars:
_, self.config, num_chars = self.get_characters(config)
self.config.model_args.num_chars = num_chars
self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
self.embedded_speaker_dim = 0
@ -382,19 +385,17 @@ class AlignTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint(
self, config, checkpoint_path, eval=False
@ -430,3 +431,19 @@ class AlignTTS(BaseTTS):
def on_epoch_start(self, trainer):
"""Set AlignTTS training phase on epoch start."""
self.phase = self._set_phase(trainer.config, trainer.total_steps_done)
@staticmethod
def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (AlignTTSConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return AlignTTS(new_config, ap, tokenizer, speaker_manager)

View File

@ -9,6 +9,8 @@ from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
@ -17,8 +19,14 @@ from TTS.utils.training import gradual_training_scheduler
class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2"""
def __init__(self, config: Coqpit):
super().__init__(config)
def __init__(
self,
config: "TacotronConfig",
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields as class attributes
for key in config:
@ -107,6 +115,16 @@ class BaseTacotron(BaseTTS):
"""Get the model criterion used in training."""
return TacotronLoss(self.config)
@staticmethod
def init_from_config(config: Coqpit):
"""Initialize model from config."""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config)
return BaseTacotron(config, ap, tokenizer, speaker_manager)
#############################
# COMMON COMPUTE FUNCTIONS
#############################

View File

@ -1,6 +1,6 @@
import os
import random
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union
import torch
import torch.distributed as dist
@ -9,33 +9,44 @@ from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
class BaseTTS(BaseModel):
class BaseTTS(BaseTrainerModel):
"""Base `tts` class. Every new `tts` model must inherit this.
It defines common `tts` specific functions on top of `Model` implementation.
Notes on input/output tensor shapes:
Any input or output tensor of the model must be shaped as
- 3D tensors `batch x time x channels`
- 2D tensors `batch x channels`
- 1D tensors `batch x 1`
"""
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
super().__init__()
self.config = config
self.ap = ap
self.tokenizer = tokenizer
self.speaker_manager = speaker_manager
self.language_manager = language_manager
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type.
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
`ModelArgs` has all the fields reuqired to initialize the model architecture.
`ModelConfig` has all the fields required for training, inference and containes `ModelArgs`.
If the config is for training with a name like "*Config", then the model args are embeded in the
config.model_args
@ -44,8 +55,11 @@ class BaseTTS(BaseModel):
"""
# don't use isintance not to import recursively
if "Config" in config.__class__.__name__:
config_num_chars = (
self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars
)
num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
if "characters" in config:
_, self.config, num_chars = self.get_characters(config)
self.config.num_chars = num_chars
if hasattr(self.config, "model_args"):
config.model_args.num_chars = num_chars
@ -58,22 +72,6 @@ class BaseTTS(BaseModel):
else:
raise ValueError("config must be either a *Config or *Args")
@staticmethod
def get_characters(config: Coqpit) -> str:
# TODO: implement CharacterProcessor
if config.characters is not None:
symbols, phonemes = make_symbols(**config.characters)
else:
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
config.characters = CharactersConfig(**parse_symbols())
model_characters = phonemes if config.use_phonemes else symbols
num_chars = len(model_characters) + getattr(config, "add_blank", False)
return model_characters, config, num_chars
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
@ -170,8 +168,8 @@ class BaseTTS(BaseModel):
Dict: [description]
"""
# setup input batch
text_input = batch["text"]
text_lengths = batch["text_lengths"]
text_input = batch["token_id"]
text_lengths = batch["token_id_lengths"]
speaker_names = batch["speaker_names"]
linear_input = batch["linear"]
mel_input = batch["mel"]
@ -239,7 +237,7 @@ class BaseTTS(BaseModel):
config: Coqpit,
assets: Dict,
is_eval: bool,
data_items: List,
samples: Union[List[Dict], List[List]],
verbose: bool,
num_gpus: int,
rank: int = None,
@ -247,8 +245,6 @@ class BaseTTS(BaseModel):
if is_eval and not config.run_eval:
loader = None
else:
ap = assets["audio_processor"]
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
@ -264,12 +260,8 @@ class BaseTTS(BaseModel):
speaker_id_mapping = None
d_vector_mapping = None
# setup custom symbols if needed
custom_symbols = None
if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config)
if hasattr(self, "language_manager"):
# setup multi-lingual attributes
if hasattr(self, "language_manager") and self.language_manager is not None:
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
@ -279,74 +271,40 @@ class BaseTTS(BaseModel):
# init dataloader
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
meta_data=data_items,
ap=ap,
characters=config.characters,
custom_symbols=custom_symbols,
add_blank=config["add_blank"],
samples=samples,
ap=self.ap,
return_wav=config.return_wav if "return_wav" in config else False,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
min_seq_len=config.min_seq_len,
max_seq_len=config.max_seq_len,
min_text_len=config.min_text_len,
max_text_len=config.max_text_len,
min_audio_len=config.min_audio_len,
max_audio_len=config.max_audio_len,
phoneme_cache_path=config.phoneme_cache_path,
use_phonemes=config.use_phonemes,
phoneme_language=config.phoneme_language,
enable_eos_bos=config.enable_eos_bos_chars,
precompute_num_workers=config.precompute_num_workers,
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest,
language_id_mapping=language_id_mapping,
)
# pre-compute phonemes
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
if hasattr(self, "eval_data_items") and is_eval:
dataset.items = self.eval_data_items
elif hasattr(self, "train_data_items") and not is_eval:
dataset.items = self.train_data_items
else:
# precompute phonemes for precise estimate of sequence lengths.
# otherwise `dataset.sort_items()` uses raw text lengths
dataset.compute_input_seq(config.num_loader_workers)
# TODO: find a more efficient solution
# cheap hack - store items in the model state to avoid recomputing when reinit the dataset
if is_eval:
self.eval_data_items = dataset.items
else:
self.train_data_items = dataset.items
# halt DDP processes for the main process to finish computing the phoneme cache
# wait all the DDP process to be ready
if num_gpus > 1:
dist.barrier()
# sort input sequences from short to long
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
# compute pitch frames and write to files.
if config.compute_f0 and rank in [None, 0]:
if not os.path.exists(config.f0_cache_path):
dataset.pitch_extractor.compute_pitch(
ap, config.get("f0_cache_path", None), config.num_loader_workers
)
# halt DDP processes for the main process to finish computing the F0 cache
if num_gpus > 1:
dist.barrier()
# load pitch stats computed above by all the workers
if config.compute_f0:
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
dataset.preprocess_samples()
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# Weighted samplers
# TODO: make this DDP amenable
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
@ -357,17 +315,17 @@ class BaseTTS(BaseModel):
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.items)
sampler = get_language_weighted_sampler(dataset.samples)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.items)
sampler = get_speaker_weighted_sampler(dataset.samples)
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False,
shuffle=False, # shuffle is done in the dataset.
collate_fn=dataset.collate_fn,
drop_last=False,
drop_last=False, # setting this False might cause issues in AMP training.
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
@ -403,7 +361,6 @@ class BaseTTS(BaseModel):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
ap = assets["audio_processor"]
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
@ -415,17 +372,15 @@ class BaseTTS(BaseModel):
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, Tuple
from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
@ -14,7 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
@dataclass
@ -170,17 +171,22 @@ class ForwardTTS(BaseTTS):
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
self._set_model_args(config)
super().__init__(config)
self.speaker_manager = speaker_manager
self.init_multispeaker(config)
self.max_duration = self.args.max_duration
self.use_aligner = self.args.use_aligner
self.use_pitch = self.args.use_pitch
self.use_binary_alignment_loss = False
self.binary_loss_weight = 0.0
self.length_scale = (
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
@ -255,7 +261,7 @@ class ForwardTTS(BaseTTS):
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels)
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@staticmethod
@ -638,8 +644,9 @@ class ForwardTTS(BaseTTS):
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
alignment_soft=outputs["alignment_soft"],
alignment_hard=outputs["alignment_mas"],
binary_loss_weight=self.binary_loss_weight,
)
# compute duration error
durations_pred = outputs["durations"]
@ -666,17 +673,12 @@ class ForwardTTS(BaseTTS):
# plot pitch figures
if self.args.use_pitch:
pitch = batch["pitch"]
pitch_avg_expanded, _ = self.expand_encoder_outputs(
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
)
pitch = pitch[0, 0].data.cpu().numpy()
# TODO: denormalize before plotting
pitch = abs(pitch)
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy())
pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy())
chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
pitch_figures = {
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
"pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
"pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
}
figures.update(pitch_figures)
@ -692,19 +694,17 @@ class ForwardTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint(
self, config, checkpoint_path, eval=False
@ -721,6 +721,21 @@ class ForwardTTS(BaseTTS):
return ForwardTTSLoss(self.config)
def on_train_step_start(self, trainer):
"""Enable binary alignment loss when needed"""
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
self.use_binary_alignment_loss = True
"""Schedule binary loss weight."""
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
@staticmethod
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (ForwardTTSConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return ForwardTTS(new_config, ap, tokenizer, speaker_manager)

View File

@ -1,5 +1,5 @@
import math
from typing import Dict, Tuple, Union
from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
@ -14,6 +14,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
@ -39,18 +40,31 @@ class GlowTTS(BaseTTS):
Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
Examples:
Init only model layers.
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig(num_chars=2)
>>> model = GlowTTS(config)
Fully init a model ready for action. All the class attributes and class members
(e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values.
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig()
>>> model = GlowTTS(config)
>>> model = GlowTTS.init_from_config(config, verbose=False)
"""
def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: GlowTTSConfig,
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config)
self.speaker_manager = speaker_manager
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields to `self`
# for fewer code change
@ -58,7 +72,6 @@ class GlowTTS(BaseTTS):
for key in config:
setattr(self, key, config[key])
_, self.config, self.num_chars = self.get_characters(config)
self.decoder_output_dim = config.out_channels
# init multi-speaker layers if necessary
@ -94,25 +107,25 @@ class GlowTTS(BaseTTS):
def init_multispeaker(self, config: Coqpit):
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets
speaker embedding vector dimension to the d-vector dimension from the config.
Args:
config (Coqpit): Model configuration.
"""
self.embedded_speaker_dim = 0
# init speaker manager
if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
raise ValueError(
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
)
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers
# set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file:
if config.use_d_vector_file:
self.embedded_speaker_dim = (
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
)
if self.speaker_manager is not None:
assert (
config.d_vector_dim == self.speaker_manager.d_vector_dim
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
@ -170,6 +183,8 @@ class GlowTTS(BaseTTS):
if g is not None:
if hasattr(self, "emb_g"):
# use speaker embedding layer
if not g.size(): # if is a scalar
g = g.unsqueeze(0) # unsqueeze
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
else:
# use d-vector
@ -180,12 +195,33 @@ class GlowTTS(BaseTTS):
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
- x: :math:`[B, T]`
- x_lenghts::math:`B`
- y: :math:`[B, T, C]`
- y_lengths::math:`B`
- g: :math:`[B, C] or B`
Args:
x (torch.Tensor):
Input text sequence ids. :math:`[B, T_en]`
x_lengths (torch.Tensor):
Lengths of input text sequences. :math:`[B]`
y (torch.Tensor):
Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
y_lengths (torch.Tensor):
Lengths of target mel-spectrogram frames. :math:`[B]`
aux_input (Dict):
Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model.
:math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model usind speaker-embedding
layer. :math:`B`
Returns:
Dict:
- z: :math: `[B, T_de, C]`
- logdet: :math:`B`
- y_mean: :math:`[B, T_de, C]`
- y_log_scale: :math:`[B, T_de, C]`
- alignments: :math:`[B, T_en, T_de]`
- durations_log: :math:`[B, T_en, 1]`
- total_durations_log: :math:`[B, T_en, 1]`
"""
# [B, T, C] -> [B, C, T]
y = y.transpose(1, 2)
@ -206,9 +242,9 @@ class GlowTTS(BaseTTS):
with torch.no_grad():
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
@ -255,9 +291,9 @@ class GlowTTS(BaseTTS):
# find the alignment path between z and encoder output
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
@ -422,20 +458,18 @@ class GlowTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
@torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@torch.no_grad()
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
@ -446,7 +480,6 @@ class GlowTTS(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
ap = assets["audio_processor"]
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
@ -461,18 +494,16 @@ class GlowTTS(BaseTTS):
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs["outputs"]["model_outputs"], ap, output_fig=False
outputs["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
return test_figures, test_audios
@ -499,7 +530,8 @@ class GlowTTS(BaseTTS):
self.store_inverse()
assert not self.training
def get_criterion(self):
@staticmethod
def get_criterion():
from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel
return GlowTTSLoss()
@ -507,3 +539,20 @@ class GlowTTS(BaseTTS):
def on_train_step_start(self, trainer):
"""Decide on every training step wheter enable/disable data depended initialization."""
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
@staticmethod
def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
"""Initiate model from config
Args:
config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return GlowTTS(new_config, ap, tokenizer, speaker_manager)

View File

@ -1,7 +1,8 @@
# coding: utf-8
from typing import Dict, List, Union
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
@ -10,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -24,12 +26,15 @@ class Tacotron(BaseTacotron):
a multi-speaker model. Defaults to None.
"""
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
def __init__(
self,
config: "TacotronConfig",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
self.speaker_manager = speaker_manager
chars, self.config, _ = self.get_characters(config)
config.num_chars = self.num_chars = len(chars)
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields to `self`
# for fewer code change
@ -302,16 +307,30 @@ class Tacotron(BaseTacotron):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (TacotronConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return Tacotron(new_config, ap, tokenizer, speaker_manager)

View File

@ -1,9 +1,8 @@
# coding: utf-8
from typing import Dict
from typing import Dict, List, Union
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
@ -12,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -40,12 +40,16 @@ class Tacotron2(BaseTacotron):
Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None.
"""
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
def __init__(
self,
config: "Tacotron2Config",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
self.speaker_manager = speaker_manager
chars, self.config, _ = self.get_characters(config)
config.num_chars = len(chars)
self.decoder_output_dim = config.out_channels
# pass all config fields to `self`
@ -325,16 +329,30 @@ class Tacotron2(BaseTacotron):
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
"""Log training progress."""
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (Tacotron2Config): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(new_config, samples)
return Tacotron2(new_config, ap, tokenizer, speaker_manager)

File diff suppressed because it is too large Load Diff

View File

@ -1,20 +0,0 @@
## Utilities to Convert Models to Tensorflow2
Here there are experimental utilities to convert trained Torch models to Tensorflow (2.2>=).
Converting Torch models to TF enables all the TF toolkit to be used for better deployment and device specific optimizations.
Note that we do not plan to share training scripts for Tensorflow in near future. But any contribution in that direction would be more than welcome.
To see how you can use TF model at inference, check the notebook.
This is an experimental release. If you encounter an error, please put an issue or in the best send a PR but you are mostly on your own.
### Converting a Model
- Run ```convert_tacotron2_torch_to_tf.py --torch_model_path /path/to/torch/model.pth.tar --config_path /path/to/model/config.json --output_path /path/to/output/tf/model``` with the right arguments.
### Known issues ans limitations
- We use a custom model load/save mechanism which enables us to store model related information with models weights. (Similar to Torch). However, it is prone to random errors.
- Current TF model implementation is slightly slower than Torch model. Hopefully, it'll get better with improving TF support for eager mode and ```tf.function```.
- TF implementation of Tacotron2 only supports regular Tacotron2 as in the paper.
- You can only convert models trained after TF model implementation since model layers has been updated in Torch model.

View File

@ -1,301 +0,0 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops
# from tensorflow_addons.seq2seq import BahdanauAttention
# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class Linear(keras.layers.Layer):
def __init__(self, units, use_bias, **kwargs):
super().__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
self.activation = keras.layers.ReLU()
def call(self, x):
"""
shapes:
x: B x T x C
"""
return self.activation(self.linear_layer(x))
class LinearBN(keras.layers.Layer):
def __init__(self, units, use_bias, **kwargs):
super().__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
self.batch_normalization = keras.layers.BatchNormalization(
axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization"
)
self.activation = keras.layers.ReLU()
def call(self, x, training=None):
"""
shapes:
x: B x T x C
"""
out = self.linear_layer(x)
out = self.batch_normalization(out, training=training)
return self.activation(out)
class Prenet(keras.layers.Layer):
def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs):
super().__init__(**kwargs)
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
self.linear_layers = []
if prenet_type == "bn":
self.linear_layers += [
LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
]
elif prenet_type == "original":
self.linear_layers += [
Linear(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
]
else:
raise RuntimeError(" [!] Unknown prenet type.")
if prenet_dropout:
self.dropout = keras.layers.Dropout(rate=0.5)
def call(self, x, training=None):
"""
shapes:
x: B x T x C
"""
for linear in self.linear_layers:
if self.prenet_dropout:
x = self.dropout(linear(x), training=training)
else:
x = linear(x)
return x
def _sigmoid_norm(score):
attn_weights = tf.nn.sigmoid(score)
attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
return attn_weights
class Attention(keras.layers.Layer):
"""TODO: implement forward_attention
TODO: location sensitive attention
TODO: implement attention windowing"""
def __init__(
self,
attn_dim,
use_loc_attn,
loc_attn_n_filters,
loc_attn_kernel_size,
use_windowing,
norm,
use_forward_attn,
use_trans_agent,
use_forward_attn_mask,
**kwargs,
):
super().__init__(**kwargs)
self.use_loc_attn = use_loc_attn
self.loc_attn_n_filters = loc_attn_n_filters
self.loc_attn_kernel_size = loc_attn_kernel_size
self.use_windowing = use_windowing
self.norm = norm
self.use_forward_attn = use_forward_attn
self.use_trans_agent = use_trans_agent
self.use_forward_attn_mask = use_forward_attn_mask
self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer")
self.inputs_layer = tf.keras.layers.Dense(
attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer"
)
self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer")
if use_loc_attn:
self.location_conv1d = keras.layers.Conv1D(
filters=loc_attn_n_filters,
kernel_size=loc_attn_kernel_size,
padding="same",
use_bias=False,
name="location_layer/location_conv1d",
)
self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense")
if norm == "softmax":
self.norm_func = tf.nn.softmax
elif norm == "sigmoid":
self.norm_func = _sigmoid_norm
else:
raise ValueError("Unknown value for attention norm type")
def init_states(self, batch_size, value_length):
states = []
if self.use_loc_attn:
attention_cum = tf.zeros([batch_size, value_length])
attention_old = tf.zeros([batch_size, value_length])
states = [attention_cum, attention_old]
if self.use_forward_attn:
alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1)
states.append(alpha)
return tuple(states)
def process_values(self, values):
"""cache values for decoder iterations"""
# pylint: disable=attribute-defined-outside-init
self.processed_values = self.inputs_layer(values)
self.values = values
def get_loc_attn(self, query, states):
"""compute location attention, query layer and
unnorm. attention weights"""
attention_cum, attention_old = states[:2]
attn_cat = tf.stack([attention_old, attention_cum], axis=2)
processed_query = self.query_layer(tf.expand_dims(query, 1))
processed_attn = self.location_dense(self.location_conv1d(attn_cat))
score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn))
score = tf.squeeze(score, axis=2)
return score, processed_query
def get_attn(self, query):
"""compute query layer and unnormalized attention weights"""
processed_query = self.query_layer(tf.expand_dims(query, 1))
score = self.v(tf.nn.tanh(self.processed_values + processed_query))
score = tf.squeeze(score, axis=2)
return score, processed_query
def apply_score_masking(self, score, mask): # pylint: disable=no-self-use
"""ignore sequence paddings"""
padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
# Bias so padding positions do not contribute to attention distribution.
score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32)
return score
def apply_forward_attention(self, alignment, alpha): # pylint: disable=no-self-use
# forward attention
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
# compute transition potentials
new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
# renormalize attention weights
new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
return new_alpha
def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
states = []
if self.use_loc_attn:
states = [old_states[0] + scores_norm, attn_weights]
if self.use_forward_attn:
states.append(new_alpha)
return tuple(states)
def call(self, query, states):
"""
shapes:
query: B x D
"""
if self.use_loc_attn:
score, _ = self.get_loc_attn(query, states)
else:
score, _ = self.get_attn(query)
# TODO: masking
# if mask is not None:
# self.apply_score_masking(score, mask)
# attn_weights shape == (batch_size, max_length, 1)
# normalize attention scores
scores_norm = self.norm_func(score)
attn_weights = scores_norm
# apply forward attention
new_alpha = None
if self.use_forward_attn:
new_alpha = self.apply_forward_attention(attn_weights, states[-1])
attn_weights = new_alpha
# update states tuple
# states = (cum_attn_weights, attn_weights, new_alpha)
states = self.update_states(states, scores_norm, attn_weights, new_alpha)
# context_vector shape after sum == (batch_size, hidden_size)
context_vector = tf.matmul(
tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False
)
context_vector = tf.squeeze(context_vector, axis=1)
return context_vector, attn_weights, states
# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
# dtype = processed_query.dtype
# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])
# class LocationSensitiveAttention(BahdanauAttention):
# def __init__(self,
# units,
# memory=None,
# memory_sequence_length=None,
# normalize=False,
# probability_fn="softmax",
# kernel_initializer="glorot_uniform",
# dtype=None,
# name="LocationSensitiveAttention",
# location_attention_filters=32,
# location_attention_kernel_size=31):
# super( self).__init__(units=units,
# memory=memory,
# memory_sequence_length=memory_sequence_length,
# normalize=normalize,
# probability_fn='softmax', ## parent module default
# kernel_initializer=kernel_initializer,
# dtype=dtype,
# name=name)
# if probability_fn == 'sigmoid':
# self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
# self.location_dense = keras.layers.Dense(units, use_bias=False)
# # self.v = keras.layers.Dense(1, use_bias=True)
# def _location_sensitive_score(self, processed_query, keys, processed_loc):
# processed_query = tf.expand_dims(processed_query, 1)
# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])
# def _location_sensitive(self, alignment_cum, alignment_old):
# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
# return self.location_dense(self.location_conv(alignment_cat))
# def _sigmoid_normalization(self, score):
# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)
# # def _apply_masking(self, score, mask):
# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
# # # Bias so padding positions do not contribute to attention distribution.
# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
# # return score
# def _calculate_attention(self, query, state):
# alignment_cum, alignment_old = state[:2]
# processed_query = self.query_layer(
# query) if self.query_layer else query
# processed_loc = self._location_sensitive(alignment_cum, alignment_old)
# score = self._location_sensitive_score(
# processed_query,
# self.keys,
# processed_loc)
# alignment = self.probability_fn(score, state)
# alignment_cum = alignment_cum + alignment
# state[0] = alignment_cum
# state[1] = alignment
# return alignment, state
# def compute_context(self, alignments):
# expanded_alignments = tf.expand_dims(alignments, 1)
# context = tf.matmul(expanded_alignments, self.values)
# context = tf.squeeze(context, [1])
# return context
# # def call(self, query, state):
# # alignment, next_state = self._calculate_attention(query, state)
# # return alignment, next_state

View File

@ -1,322 +0,0 @@
import tensorflow as tf
from tensorflow import keras
from TTS.tts.tf.layers.tacotron.common_layers import Attention, Prenet
from TTS.tts.tf.utils.tf_utils import shape_list
# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class ConvBNBlock(keras.layers.Layer):
def __init__(self, filters, kernel_size, activation, **kwargs):
super().__init__(**kwargs)
self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d")
self.batch_normalization = keras.layers.BatchNormalization(
axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization"
)
self.dropout = keras.layers.Dropout(rate=0.5, name="dropout")
self.activation = keras.layers.Activation(activation, name="activation")
def call(self, x, training=None):
o = self.convolution1d(x)
o = self.batch_normalization(o, training=training)
o = self.activation(o)
o = self.dropout(o, training=training)
return o
class Postnet(keras.layers.Layer):
def __init__(self, output_filters, num_convs, **kwargs):
super().__init__(**kwargs)
self.convolutions = []
self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0"))
for idx in range(1, num_convs - 1):
self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}"))
self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}"))
def call(self, x, training=None):
o = x
for layer in self.convolutions:
o = layer(o, training=training)
return o
class Encoder(keras.layers.Layer):
def __init__(self, output_input_dim, **kwargs):
super().__init__(**kwargs)
self.convolutions = []
for idx in range(3):
self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}"))
self.lstm = keras.layers.Bidirectional(
keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm"
)
def call(self, x, training=None):
o = x
for layer in self.convolutions:
o = layer(o, training=training)
o = self.lstm(o)
return o
class Decoder(keras.layers.Layer):
# pylint: disable=unused-argument
def __init__(
self,
frame_dim,
r,
attn_type,
use_attn_win,
attn_norm,
prenet_type,
prenet_dropout,
use_forward_attn,
use_trans_agent,
use_forward_attn_mask,
use_location_attn,
attn_K,
separate_stopnet,
speaker_emb_dim,
enable_tflite,
**kwargs,
):
super().__init__(**kwargs)
self.frame_dim = frame_dim
self.r_init = tf.constant(r, dtype=tf.int32)
self.r = tf.constant(r, dtype=tf.int32)
self.output_dim = r * self.frame_dim
self.separate_stopnet = separate_stopnet
self.enable_tflite = enable_tflite
# layer constants
self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
# model dimensions
self.query_dim = 1024
self.decoder_rnn_dim = 1024
self.prenet_dim = 256
self.attn_dim = 128
self.p_attention_dropout = 0.1
self.p_decoder_dropout = 0.1
self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet")
self.attention_rnn = keras.layers.LSTMCell(
self.query_dim,
use_bias=True,
name="attention_rnn",
)
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
# TODO: implement other attn options
self.attention = Attention(
attn_dim=self.attn_dim,
use_loc_attn=True,
loc_attn_n_filters=32,
loc_attn_kernel_size=31,
use_windowing=False,
norm=attn_norm,
use_forward_attn=use_forward_attn,
use_trans_agent=use_trans_agent,
use_forward_attn_mask=use_forward_attn_mask,
name="attention",
)
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn")
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer")
self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer")
def set_max_decoder_steps(self, new_max_steps):
self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)
def set_r(self, new_r):
self.r = tf.constant(new_r, dtype=tf.int32)
self.output_dim = self.frame_dim * new_r
def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
zero_frame = tf.zeros([batch_size, self.frame_dim])
zero_context = tf.zeros([batch_size, memory_dim])
attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
attention_states = self.attention.init_states(batch_size, memory_length)
return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states
def step(self, prenet_next, states, memory_seq_length=None, training=None):
_, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
attention_rnn_input = tf.concat([prenet_next, context_next], -1)
attention_rnn_output, attention_rnn_state = self.attention_rnn(
attention_rnn_input, attention_rnn_state, training=training
)
attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
decoder_rnn_output, decoder_rnn_state = self.decoder_rnn(
decoder_rnn_input, decoder_rnn_state, training=training
)
decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
output_frame = self.linear_projection(linear_projection_input, training=training)
stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
stopnet_output = self.stopnet(stopnet_input, training=training)
output_frame = output_frame[:, : self.r * self.frame_dim]
states = (
output_frame[:, self.frame_dim * (self.r - 1) :],
context,
attention_rnn_state,
decoder_rnn_state,
attention_states,
)
return output_frame, stopnet_output, states, attention
def decode(self, memory, states, frames, memory_seq_length=None):
B, _, _ = shape_list(memory)
num_iter = shape_list(frames)[1] // self.r
# init states
frame_zero = tf.expand_dims(states[0], 1)
frames = tf.concat([frame_zero, frames], axis=1)
outputs = tf.TensorArray(dtype=tf.float32, size=num_iter)
attentions = tf.TensorArray(dtype=tf.float32, size=num_iter)
stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter)
# pre-computes
self.attention.process_values(memory)
prenet_output = self.prenet(frames, training=True)
step_count = tf.constant(0, dtype=tf.int32)
def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
prenet_next = prenet_output[:, step]
output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length)
outputs = outputs.write(step, output)
attentions = attentions.write(step, attention)
stop_tokens = stop_tokens.write(step, stop_token)
return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
_, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop(
lambda *arg: True,
_body,
loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=num_iter,
)
outputs = outputs.stack()
attentions = attentions.stack()
stop_tokens = stop_tokens.stack()
outputs = tf.transpose(outputs, [1, 0, 2])
attentions = tf.transpose(attentions, [1, 0, 2])
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
stop_tokens = tf.squeeze(stop_tokens, axis=2)
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def decode_inference(self, memory, states):
B, _, _ = shape_list(memory)
# init states
outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
# pre-computes
self.attention.process_values(memory)
# iter vars
stop_flag = tf.constant(False, dtype=tf.bool)
step_count = tf.constant(0, dtype=tf.int32)
def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
frame_next = states[0]
prenet_next = self.prenet(frame_next, training=False)
output, stop_token, states, attention = self.step(prenet_next, states, None, training=False)
stop_token = tf.math.sigmoid(stop_token)
outputs = outputs.write(step, output)
attentions = attentions.write(step, attention)
stop_tokens = stop_tokens.write(step, stop_token)
stop_flag = tf.greater(stop_token, self.stop_thresh)
stop_flag = tf.reduce_all(stop_flag)
return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag
cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
_, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop(
cond,
_body,
loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps,
)
outputs = outputs.stack()
attentions = attentions.stack()
stop_tokens = stop_tokens.stack()
outputs = tf.transpose(outputs, [1, 0, 2])
attentions = tf.transpose(attentions, [1, 0, 2])
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
stop_tokens = tf.squeeze(stop_tokens, axis=2)
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def decode_inference_tflite(self, memory, states):
"""Inference with TF-Lite compatibility. It assumes
batch_size is 1"""
# init states
# dynamic_shape is not supported in TFLite
outputs = tf.TensorArray(
dtype=tf.float32,
size=self.max_decoder_steps,
element_shape=tf.TensorShape([self.output_dim]),
clear_after_read=False,
dynamic_size=False,
)
# stop_flags = tf.TensorArray(dtype=tf.bool,
# size=self.max_decoder_steps,
# element_shape=tf.TensorShape(
# []),
# clear_after_read=False,
# dynamic_size=False)
attentions = ()
stop_tokens = ()
# pre-computes
self.attention.process_values(memory)
# iter vars
stop_flag = tf.constant(False, dtype=tf.bool)
step_count = tf.constant(0, dtype=tf.int32)
def _body(step, memory, states, outputs, stop_flag):
frame_next = states[0]
prenet_next = self.prenet(frame_next, training=False)
output, stop_token, states, _ = self.step(prenet_next, states, None, training=False)
stop_token = tf.math.sigmoid(stop_token)
stop_flag = tf.greater(stop_token, self.stop_thresh)
stop_flag = tf.reduce_all(stop_flag)
# stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
outputs = outputs.write(step, tf.reshape(output, [-1]))
return step + 1, memory, states, outputs, stop_flag
cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
step_count, memory, states, outputs, stop_flag = tf.while_loop(
cond,
_body,
loop_vars=(step_count, memory, states, outputs, stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps,
)
outputs = outputs.stack()
outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter
outputs = tf.expand_dims(outputs, axis=[0])
outputs = tf.transpose(outputs, [1, 0, 2])
outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
if training:
return self.decode(memory, states, frames, memory_seq_length)
if self.enable_tflite:
return self.decode_inference_tflite(memory, states)
return self.decode_inference(memory, states)

View File

@ -1,116 +0,0 @@
import tensorflow as tf
from tensorflow import keras
from TTS.tts.tf.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.tf.utils.tf_utils import shape_list
# pylint: disable=too-many-ancestors, abstract-method
class Tacotron2(keras.models.Model):
def __init__(
self,
num_chars,
num_speakers,
r,
out_channels=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
attn_K=4,
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True,
bidirectional_decoder=False,
enable_tflite=False,
):
super().__init__()
self.r = r
self.decoder_output_dim = decoder_output_dim
self.out_channels = out_channels
self.bidirectional_decoder = bidirectional_decoder
self.num_speakers = num_speakers
self.speaker_embed_dim = 256
self.enable_tflite = enable_tflite
self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding")
self.encoder = Encoder(512, name="encoder")
# TODO: most of the decoder args have no use at the momment
self.decoder = Decoder(
decoder_output_dim,
r,
attn_type=attn_type,
use_attn_win=attn_win,
attn_norm=attn_norm,
prenet_type=prenet_type,
prenet_dropout=prenet_dropout,
use_forward_attn=forward_attn,
use_trans_agent=trans_agent,
use_forward_attn_mask=forward_attn_mask,
use_location_attn=location_attn,
attn_K=attn_K,
separate_stopnet=separate_stopnet,
speaker_emb_dim=self.speaker_embed_dim,
name="decoder",
enable_tflite=enable_tflite,
)
self.postnet = Postnet(out_channels, 5, name="postnet")
@tf.function(experimental_relax_shapes=True)
def call(self, characters, text_lengths=None, frames=None, training=None):
if training:
return self.training(characters, text_lengths, frames)
if not training:
return self.inference(characters)
raise RuntimeError(" [!] Set model training mode True or False")
def training(self, characters, text_lengths, frames):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=True)
encoder_output = self.encoder(embedding_vectors, training=True)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(
encoder_output, decoder_states, frames, text_lengths, training=True
)
postnet_frames = self.postnet(decoder_frames, training=True)
output_frames = decoder_frames + postnet_frames
return decoder_frames, output_frames, attentions, stop_tokens
def inference(self, characters):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=False)
encoder_output = self.encoder(embedding_vectors, training=False)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
postnet_frames = self.postnet(decoder_frames, training=False)
output_frames = decoder_frames + postnet_frames
print(output_frames.shape)
return decoder_frames, output_frames, attentions, stop_tokens
@tf.function(
experimental_relax_shapes=True,
input_signature=[
tf.TensorSpec([1, None], dtype=tf.int32),
],
)
def inference_tflite(self, characters):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=False)
encoder_output = self.encoder(embedding_vectors, training=False)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
postnet_frames = self.postnet(decoder_frames, training=False)
output_frames = decoder_frames + postnet_frames
print(output_frames.shape)
return decoder_frames, output_frames, attentions, stop_tokens
def build_inference(
self,
):
# TODO: issue https://github.com/PyCQA/pylint/issues/3613
input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) # pylint: disable=unexpected-keyword-arg
self(input_ids)

View File

@ -1,87 +0,0 @@
import numpy as np
import tensorflow as tf
# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
def tf_create_dummy_inputs():
"""Create dummy inputs for TF Tacotron2 model"""
batch_size = 4
max_input_length = 32
max_mel_length = 128
pad = 1
n_chars = 24
input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32)
input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size])
input_lengths[-1] = max_input_length
input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32)
mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80])
mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size])
mel_lengths[-1] = max_mel_length
mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)
return input_ids, input_lengths, mel_outputs, mel_lengths
def compare_torch_tf(torch_tensor, tf_tensor):
"""Compute the average absolute difference b/w torch and tf tensors"""
return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean()
def convert_tf_name(tf_name):
"""Convert certain patterns in TF layer names to Torch patterns"""
tf_name_tmp = tf_name
tf_name_tmp = tf_name_tmp.replace(":0", "")
tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0")
tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1")
tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh")
tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight")
tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight")
tf_name_tmp = tf_name_tmp.replace("/beta", "/bias")
tf_name_tmp = tf_name_tmp.replace("/", ".")
return tf_name_tmp
def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
"""Transfer weigths from torch state_dict to TF variables"""
print(" > Passing weights from Torch to TF ...")
for tf_var in tf_vars:
torch_var_name = var_map_dict[tf_var.name]
print(f" | > {tf_var.name} <-- {torch_var_name}")
# if tuple, it is a bias variable
if not isinstance(torch_var_name, tuple):
torch_layer_name = ".".join(torch_var_name.split(".")[-2:])
torch_weight = state_dict[torch_var_name]
if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name:
# out_dim, in_dim, filter -> filter, in_dim, out_dim
numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy()
elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
# if variable is for bidirectional lstm and it is a bias vector there
# needs to be pre-defined two matching torch bias vectors
elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name:
bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name]
assert len(bias_vectors) == 2
numpy_weight = bias_vectors[0] + bias_vectors[1]
elif "rnn" in tf_var.name and "kernel" in tf_var.name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
elif "rnn" in tf_var.name and "bias" in tf_var.name:
bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key]
assert len(bias_vectors) == 2
numpy_weight = bias_vectors[0] + bias_vectors[1]
elif "linear_layer" in torch_layer_name and "weight" in torch_var_name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
else:
numpy_weight = torch_weight.detach().cpu().numpy()
assert np.all(
tf_var.shape == numpy_weight.shape
), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
tf.keras.backend.set_value(tf_var, numpy_weight)
return tf_vars
def load_tf_vars(model_tf, tf_vars):
for tf_var in tf_vars:
model_tf.get_layer(tf_var.name).set_weights(tf_var)
return model_tf

View File

@ -1,105 +0,0 @@
import datetime
import importlib
import pickle
import fsspec
import numpy as np
import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
state = {
"model": model.weights,
"optimizer": optimizer,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
"r": r,
}
state.update(kwargs)
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
layer_name = tf_var.name
try:
chkp_var_value = chkp_var_dict[layer_name]
except KeyError:
class_name = list(chkp_var_dict.keys())[0].split("/")[0]
layer_name = f"{class_name}/{layer_name}"
chkp_var_value = chkp_var_dict[layer_name]
tf.keras.backend.set_value(tf_var, chkp_var_value)
if "r" in checkpoint.keys():
model.decoder.set_r(checkpoint["r"])
return model
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.max()
batch_size = sequence_length.size(0)
seq_range = np.empty([0, max_len], dtype=np.int8)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_range_expand = seq_range_expand.type_as(sequence_length)
seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
# B x T_max
return seq_range_expand < seq_length_expand
# @tf.custom_gradient
def check_gradient(x, grad_clip):
x_normed = tf.clip_by_norm(x, grad_clip)
grad_norm = tf.norm(grad_clip)
return x_normed, grad_norm
def count_parameters(model, c):
try:
return model.count_params()
except RuntimeError:
input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32"))
input_lengths = np.random.randint(100, 129, (8,))
input_lengths[-1] = 128
input_lengths = tf.convert_to_tensor(input_lengths.astype("int32"))
mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32")
mel_spec = tf.convert_to_tensor(mel_spec)
speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None
_ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids)
return model.count_params()
def setup_model(num_chars, num_speakers, c, enable_tflite=False):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower())
MyModel = getattr(MyModel, c.model)
if c.model.lower() in "tacotron":
raise NotImplementedError(" [!] Tacotron model is not ready.")
# tacotron2
model = MyModel(
num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
out_channels=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
enable_tflite=enable_tflite,
)
return model

View File

@ -1,45 +0,0 @@
import datetime
import pickle
import fsspec
import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
state = {
"model": model.weights,
"optimizer": optimizer,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
"r": r,
}
state.update(kwargs)
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
layer_name = tf_var.name
try:
chkp_var_value = chkp_var_dict[layer_name]
except KeyError:
class_name = list(chkp_var_dict.keys())[0].split("/")[0]
layer_name = f"{class_name}/{layer_name}"
chkp_var_value = chkp_var_dict[layer_name]
tf.keras.backend.set_value(tf_var, chkp_var_value)
if "r" in checkpoint.keys():
model.decoder.set_r(checkpoint["r"])
return model
def load_tflite_model(tflite_path):
tflite_model = tf.lite.Interpreter(model_path=tflite_path)
tflite_model.allocate_tensors()
return tflite_model

View File

@ -1,8 +0,0 @@
import tensorflow as tf
def shape_list(x):
"""Deal with dynamic shape in tensorflow cleanly."""
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]

View File

@ -1,27 +0,0 @@
import fsspec
import tensorflow as tf
def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True):
"""Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is
provided, else return TFLite model."""
concrete_function = model.inference_tflite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
converter.experimental_new_converter = experimental_converter
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# same model binary if outputpath is provided
with fsspec.open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model
def load_tflite_model(tflite_path):
tflite_model = tf.lite.Interpreter(model_path=tflite_path)
tflite_model.allocate_tensors()
return tflite_model

View File

@ -57,40 +57,65 @@ def sequence_mask(sequence_length, max_len=None):
return mask
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
"""Segment each sample in a batch based on the provided segment indices
Args:
x (torch.tensor): Input tensor.
segment_indices (torch.tensor): Segment indices.
segment_size (int): Expected output segment size.
pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
"""
# pad the input tensor if it is shorter than the segment size
if pad_short and x.shape[-1] < segment_size:
x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
segments = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
index_start = segment_indices[i]
index_end = index_start + segment_size
segments[i] = x[i, :, index_start:index_end]
x_i = x[i]
if pad_short and index_end > x.size(2):
# pad the sample if it is shorter than the segment size
x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
segments[i] = x_i[:, index_start:index_end]
return segments
def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
def rand_segments(
x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
):
"""Create random segments based on the input lengths.
Args:
x (torch.tensor): Input tensor.
x_lengths (torch.tensor): Input lengths.
segment_size (int): Expected output segment size.
let_short_samples (bool): Allow shorter samples than the segment size.
pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
Shapes:
- x: :math:`[B, C, T]`
- x_lengths: :math:`[B]`
"""
_x_lenghts = x_lengths.clone()
B, _, T = x.size()
if x_lengths is None:
x_lengths = T
max_idxs = x_lengths - segment_size + 1
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
if pad_short:
if T < segment_size:
x = torch.nn.functional.pad(x, (0, segment_size - T))
T = segment_size
if _x_lenghts is None:
_x_lenghts = T
len_diff = _x_lenghts - segment_size + 1
if let_short_samples:
_x_lenghts[len_diff < 0] = segment_size
len_diff = _x_lenghts - segment_size + 1
else:
assert all(
len_diff > 0
), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
ret = segment(x, segment_indices, segment_size)
return ret, segment_indices

View File

@ -8,6 +8,8 @@ import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
from TTS.config import check_config_and_model_args
class LanguageManager:
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
@ -98,6 +100,20 @@ class LanguageManager:
"""
self._save_json(file_path, self.language_id_mapping)
@staticmethod
def init_from_config(config: Coqpit) -> "LanguageManager":
"""Initialize the language manager from a Coqpit config.
Args:
config (Coqpit): Coqpit config.
"""
language_manager = None
if check_config_and_model_args(config, "use_language_embedding", True):
if config.get("language_ids_file", None):
language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
language_manager = LanguageManager(config=config)
return language_manager
def _set_file_path(path):
"""Find the language_ids.json under the given path or the above it.
@ -113,7 +129,7 @@ def _set_file_path(path):
def get_language_weighted_sampler(items: list):
language_names = np.array([item[3] for item in items])
language_names = np.array([item["language"] for item in items])
unique_language_names = np.unique(language_names).tolist()
language_ids = [unique_language_names.index(l) for l in language_names]
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])

View File

@ -9,7 +9,7 @@ import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
from TTS.config import load_config
from TTS.config import get_from_config_or_model_args_with_default, load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.utils.audio import AudioProcessor
@ -118,7 +118,7 @@ class SpeakerManager:
Returns:
Tuple[Dict, int]: speaker IDs and number of speakers.
"""
speakers = sorted({item[2] for item in items})
speakers = sorted({item["speaker_name"] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
@ -318,6 +318,42 @@ class SpeakerManager:
# TODO: implement speaker encoder
raise NotImplementedError
@staticmethod
def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
"""Initialize a speaker manager from config
Args:
config (Coqpit): Config object.
samples (Union[List[List], List[Dict]], optional): List of data samples to parse out the speaker names.
Defaults to None.
Returns:
SpeakerEncoder: Speaker encoder object.
"""
speaker_manager = None
if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False):
if samples:
speaker_manager = SpeakerManager(data_items=samples)
if get_from_config_or_model_args_with_default(config, "speaker_file", None):
speaker_manager = SpeakerManager(
speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
)
if get_from_config_or_model_args_with_default(config, "speakers_file", None):
speaker_manager = SpeakerManager(
speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None)
)
if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False):
if get_from_config_or_model_args_with_default(config, "speakers_file", None):
speaker_manager = SpeakerManager(
d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
)
if get_from_config_or_model_args_with_default(config, "d_vector_file", None):
speaker_manager = SpeakerManager(
d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None)
)
return speaker_manager
def _set_file_path(path):
"""Find the speakers.json under the given path or the above it.
@ -414,7 +450,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
def get_speaker_weighted_sampler(items: list):
speaker_names = np.array([item[2] for item in items])
speaker_names = np.array([item["speaker_name"] for item in items])
unique_speaker_names = np.unique(speaker_names).tolist()
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])

View File

@ -8,7 +8,7 @@ from torch.autograd import Variable
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)])
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
return gauss / gauss.sum()
@ -33,8 +33,8 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

View File

@ -1,46 +1,9 @@
import os
from typing import Dict
import numpy as np
import pkg_resources
import torch
from torch import nn
from .text import phoneme_to_sequence, text_to_sequence
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable
if "tensorflow" in installed or "tensorflow-gpu" in installed:
import tensorflow as tf
def text_to_seq(text, CONFIG, custom_symbols=None, language=None):
text_cleaner = [CONFIG.text_cleaner]
# text ot phonemes to sequence vector
if CONFIG.use_phonemes:
seq = np.asarray(
phoneme_to_sequence(
text,
text_cleaner,
language if language else CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
use_espeak_phonemes=CONFIG.use_espeak_phonemes,
custom_symbols=custom_symbols,
),
dtype=np.int32,
)
else:
seq = np.asarray(
text_to_sequence(
text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
),
dtype=np.int32,
)
return seq
def numpy_to_torch(np_array, dtype, cuda=False):
if np_array is None:
@ -51,13 +14,6 @@ def numpy_to_torch(np_array, dtype, cuda=False):
return tensor
def numpy_to_tf(np_array, dtype):
if np_array is None:
return None
tensor = tf.convert_to_tensor(np_array, dtype=dtype)
return tensor
def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
if cuda:
@ -103,53 +59,6 @@ def run_model_torch(
return outputs
def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
if CONFIG.gst and style_mel is not None:
raise NotImplementedError(" [!] GST inference not implemented for TF")
if speaker_id is not None:
raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
# TODO: handle multispeaker case
decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
return decoder_output, postnet_output, alignments, stop_tokens
def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
if CONFIG.gst and style_mel is not None:
raise NotImplementedError(" [!] GST inference not implemented for TfLite")
if speaker_id is not None:
raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
# get input and output details
input_details = model.get_input_details()
output_details = model.get_output_details()
# reshape input tensor for the new input shape
model.resize_tensor_input(input_details[0]["index"], inputs.shape)
model.allocate_tensors()
detail = input_details[0]
# input_shape = detail['shape']
model.set_tensor(detail["index"], inputs)
# run the model
model.invoke()
# collect outputs
decoder_output = model.get_tensor(output_details[0]["index"])
postnet_output = model.get_tensor(output_details[1]["index"])
# tflite model only returns feature frames
return decoder_output, postnet_output, None, None
def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
postnet_output = postnet_output[0].numpy()
decoder_output = decoder_output[0].numpy()
alignment = alignments[0].numpy()
stop_tokens = stop_tokens[0].numpy()
return postnet_output, decoder_output, alignment, stop_tokens
def parse_outputs_tflite(postnet_output, decoder_output):
postnet_output = postnet_output[0]
decoder_output = decoder_output[0]
return postnet_output, decoder_output
def trim_silence(wav, ap):
return wav[: ap.find_endpoint(wav)]
@ -204,16 +113,12 @@ def synthesis(
text,
CONFIG,
use_cuda,
ap,
speaker_id=None,
style_wav=None,
enable_eos_bos_chars=False, # pylint: disable=unused-argument
use_griffin_lim=False,
do_trim_silence=False,
d_vector=None,
language_id=None,
language_name=None,
backend="torch",
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
the vocoder model.
@ -231,9 +136,6 @@ def synthesis(
use_cuda (bool):
Enable/disable CUDA.
ap (TTS.tts.utils.audio.AudioProcessor):
The audio processor for extracting features and pre/post-processing audio.
speaker_id (int):
Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None.
@ -251,74 +153,51 @@ def synthesis(
language_id (int):
Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
language_name (str):
Language name corresponding to the language code used by the phonemizer. Defaults to None.
backend (str):
tf or torch. Defaults to "torch".
"""
# GST processing
style_mel = None
custom_symbols = None
if style_wav:
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
elif CONFIG.has("gst") and CONFIG.gst and not style_wav:
if CONFIG.gst.gst_style_input_weights:
style_mel = CONFIG.gst.gst_style_input_weights
if hasattr(model, "make_symbols"):
custom_symbols = model.make_symbols(CONFIG)
# preprocess the given text
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name)
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
# convert text to sequence of token IDs
text_inputs = np.asarray(
model.tokenizer.text_to_ids(text, language=language_id),
dtype=np.int32,
)
# pass tensors to backend
if backend == "torch":
if speaker_id is not None:
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
if speaker_id is not None:
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
if d_vector is not None:
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
if d_vector is not None:
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
if language_id is not None:
language_id = id_to_torch(language_id, cuda=use_cuda)
if language_id is not None:
language_id = id_to_torch(language_id, cuda=use_cuda)
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = text_inputs.unsqueeze(0)
elif backend in ["tf", "tflite"]:
# TODO: handle speaker id for tf model
style_mel = numpy_to_tf(style_mel, tf.float32)
text_inputs = numpy_to_tf(text_inputs, tf.int32)
text_inputs = tf.expand_dims(text_inputs, 0)
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
if backend == "torch":
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]
elif backend == "tf":
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
model, text_inputs, CONFIG, speaker_id, style_mel
)
model_outputs, decoder_output, alignments, stop_tokens = parse_outputs_tf(
postnet_output, decoder_output, alignments, stop_tokens
)
elif backend == "tflite":
decoder_output, postnet_output, alignments, stop_tokens = run_model_tflite(
model, text_inputs, CONFIG, speaker_id, style_mel
)
model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]
# convert outputs to numpy
# plot results
wav = None
if hasattr(model, "END2END") and model.END2END:
wav = model_outputs.squeeze(0)
else:
model_outputs = model_outputs.squeeze()
if model_outputs.ndim == 2: # [T, C_spec]
if use_griffin_lim:
wav = inv_spectrogram(model_outputs, ap, CONFIG)
wav = inv_spectrogram(model_outputs, model.ap, CONFIG)
# trim silence
if do_trim_silence:
wav = trim_silence(wav, ap)
wav = trim_silence(wav, model.ap)
else: # [T,]
wav = model_outputs
return_dict = {
"wav": wav,
"alignments": alignments,

View File

@ -1,276 +1 @@
# -*- coding: utf-8 -*-
# adapted from https://github.com/keithito/tacotron
import re
from typing import Dict, List
import gruut
from gruut_ipa import IPA
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols
# pylint: disable=unnecessary-comprehension
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_symbols = symbols
_phonemes = phonemes
# Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
# Regular expression matching punctuations, ignoring empty space
PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+"
# Table for str.translate to fix gruut/TTS phoneme mismatch
GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
def text2phone(text, language, use_espeak_phonemes=False, keep_stress=False):
"""Convert graphemes to phonemes.
Parameters:
text (str): text to phonemize
language (str): language of the text
Returns:
ph (str): phonemes as a string seperated by "|"
ph = "ɪ|g|ˈ|z|æ|m|p|ə|l"
"""
# TO REVIEW : How to have a good implementation for this?
if language == "zh-CN":
ph = chinese_text_to_phonemes(text)
return ph
if language == "ja-jp":
ph = japanese_text_to_phonemes(text)
return ph
if not gruut.is_language_supported(language):
raise ValueError(f" [!] Language {language} is not supported for phonemization.")
# Use gruut for phonemization
ph_list = []
for sentence in gruut.sentences(text, lang=language, espeak=use_espeak_phonemes):
for word in sentence:
if word.is_break:
# Use actual character for break phoneme (e.g., comma)
if ph_list:
# Join with previous word
ph_list[-1].append(word.text)
else:
# First word is punctuation
ph_list.append([word.text])
elif word.phonemes:
# Add phonemes for word
word_phonemes = []
for word_phoneme in word.phonemes:
if not keep_stress:
# Remove primary/secondary stress
word_phoneme = IPA.without_stress(word_phoneme)
word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
if word_phoneme:
# Flatten phonemes
word_phonemes.extend(word_phoneme)
if word_phonemes:
ph_list.append(word_phonemes)
# Join and re-split to break apart dipthongs, suprasegmentals, etc.
ph_words = ["|".join(word_phonemes) for word_phonemes in ph_list]
ph = "| ".join(ph_words)
return ph
def intersperse(sequence, token):
result = [token] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
def pad_with_eos_bos(phoneme_sequence, tp=None):
# pylint: disable=global-statement
global _phonemes_to_id, _bos, _eos
if tp:
_bos = tp["bos"]
_eos = tp["eos"]
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]
def phoneme_to_sequence(
text: str,
cleaner_names: List[str],
language: str,
enable_eos_bos: bool = False,
custom_symbols: List[str] = None,
tp: Dict = None,
add_blank: bool = False,
use_espeak_phonemes: bool = False,
) -> List[int]:
"""Converts a string of phonemes to a sequence of IDs.
If `custom_symbols` is provided, it will override the default symbols.
Args:
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
language (str): text language key for phonemization.
enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequenc
Returns:
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _phonemes_to_id, _phonemes
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
sequence = []
clean_text = _clean_text(text, cleaner_names)
to_phonemes = text2phone(clean_text, language, use_espeak_phonemes=use_espeak_phonemes)
if to_phonemes is None:
print("!! After phoneme conversion the result is None. -- {} ".format(clean_text))
# iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation.
for phoneme in filter(None, to_phonemes.split("|")):
sequence += _phoneme_to_sequence(phoneme)
# Append EOS char
if enable_eos_bos:
sequence = pad_with_eos_bos(sequence, tp=tp)
if add_blank:
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
return sequence
def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None):
# pylint: disable=global-statement
"""Converts a sequence of IDs back to a string"""
global _id_to_phonemes, _phonemes
if add_blank:
sequence = list(filter(lambda x: x != len(_phonemes), sequence))
result = ""
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp)
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
for symbol_id in sequence:
if symbol_id in _id_to_phonemes:
s = _id_to_phonemes[symbol_id]
result += s
return result.replace("}{", " ")
def text_to_sequence(
text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
) -> List[int]:
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
If `custom_symbols` is provided, it will override the default symbols.
Args:
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
Returns:
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _symbol_to_id, _symbols
if custom_symbols is not None:
_symbols = custom_symbols
elif tp:
_symbols, _ = make_symbols(**tp)
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while text:
m = _CURLY_RE.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
if add_blank:
sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols)
return sequence
def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
"""Converts a sequence of IDs back to a string"""
# pylint: disable=global-statement
global _id_to_symbol, _symbols
if add_blank:
sequence = list(filter(lambda x: x != len(_symbols), sequence))
if custom_symbols is not None:
_symbols = custom_symbols
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
elif tp:
_symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
result = ""
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == "@":
s = "{%s}" % s[1:]
result += s
return result.replace("}{", " ")
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception("Unknown cleaner: %s" % name)
text = cleaner(text)
return text
def _symbols_to_sequence(syms):
return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)]
def _phoneme_to_sequence(phons):
return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)]
def _arpabet_to_sequence(text):
return _symbols_to_sequence(["@" + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s not in ["~", "^", "_"]
def _should_keep_phoneme(p):
return p in _phonemes_to_id and p not in ["~", "^", "_"]
from TTS.tts.utils.text.tokenizer import TTSTokenizer

View File

@ -0,0 +1,468 @@
from dataclasses import replace
from typing import Dict
from TTS.tts.configs.shared_configs import CharactersConfig
def parse_symbols():
return {
"pad": _pad,
"eos": _eos,
"bos": _bos,
"characters": _characters,
"punctuations": _punctuations,
"phonemes": _phonemes,
}
# DEFAULT SET OF GRAPHEMES
_pad = "<PAD>"
_eos = "<EOS>"
_bos = "<BOS>"
_blank = "<BLNK>" # TODO: check if we need this alongside with PAD
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_punctuations = "!'(),-.:;? "
# DEFAULT SET OF IPA PHONEMES
# Phonemes definition (All IPA characters)
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "ˈˌːˑ"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
class BaseVocabulary:
"""Base Vocabulary class.
This class only needs a vocabulary dictionary without specifying the characters.
Args:
vocab (Dict): A dictionary of characters and their corresponding indices.
"""
def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
self.vocab = vocab
self.pad = pad
self.blank = blank
self.bos = bos
self.eos = eos
@property
def pad_id(self) -> int:
"""Return the index of the padding character. If the padding character is not specified, return the length
of the vocabulary."""
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
@property
def blank_id(self) -> int:
"""Return the index of the blank character. If the blank character is not specified, return the length of
the vocabulary."""
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
@property
def vocab(self):
"""Return the vocabulary dictionary."""
return self._vocab
@vocab.setter
def vocab(self, vocab):
"""Set the vocabulary dictionary and character mapping dictionaries."""
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
}
@staticmethod
def init_from_config(config, **kwargs):
"""Initialize from the given config."""
if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
return (
BaseVocabulary(
config.characters.vocab_dict,
config.characters.pad,
config.characters.blank,
config.characters.bos,
config.characters.eos,
),
config,
)
return BaseVocabulary(**kwargs), config
@property
def num_chars(self):
"""Return number of tokens in the vocabulary."""
return len(self._vocab)
def char_to_id(self, char: str) -> int:
"""Map a character to an token ID."""
try:
return self._char_to_id[char]
except KeyError as e:
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
def id_to_char(self, idx: int) -> str:
"""Map an token ID to a character."""
return self._id_to_char[idx]
class BaseCharacters:
"""🐸BaseCharacters class
Every new character class should inherit from this.
Characters are oredered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```.
If you need a custom order, you need to define inherit from this class and override the ```_create_vocab``` method.
Args:
characters (str):
Main set of characters to be used in the vocabulary.
punctuations (str):
Characters to be treated as punctuation.
pad (str):
Special padding character that would be ignored by the model.
eos (str):
End of the sentence character.
bos (str):
Beginning of the sentence character.
blank (str):
Optional character used between characters by some models for better prosody.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
el
is_sorted (bool):
Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True.
"""
def __init__(
self,
characters: str = None,
punctuations: str = None,
pad: str = None,
eos: str = None,
bos: str = None,
blank: str = None,
is_unique: bool = False,
is_sorted: bool = True,
) -> None:
self._characters = characters
self._punctuations = punctuations
self._pad = pad
self._eos = eos
self._bos = bos
self._blank = blank
self.is_unique = is_unique
self.is_sorted = is_sorted
self._create_vocab()
@property
def pad_id(self) -> int:
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
@property
def blank_id(self) -> int:
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
@property
def characters(self):
return self._characters
@characters.setter
def characters(self, characters):
self._characters = characters
self._create_vocab()
@property
def punctuations(self):
return self._punctuations
@punctuations.setter
def punctuations(self, punctuations):
self._punctuations = punctuations
self._create_vocab()
@property
def pad(self):
return self._pad
@pad.setter
def pad(self, pad):
self._pad = pad
self._create_vocab()
@property
def eos(self):
return self._eos
@eos.setter
def eos(self, eos):
self._eos = eos
self._create_vocab()
@property
def bos(self):
return self._bos
@bos.setter
def bos(self, bos):
self._bos = bos
self._create_vocab()
@property
def blank(self):
return self._blank
@blank.setter
def blank(self, blank):
self._blank = blank
self._create_vocab()
@property
def vocab(self):
return self._vocab
@vocab.setter
def vocab(self, vocab):
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
}
@property
def num_chars(self):
return len(self._vocab)
def _create_vocab(self):
_vocab = self._characters
if self.is_unique:
_vocab = list(set(_vocab))
if self.is_sorted:
_vocab = sorted(_vocab)
_vocab = list(_vocab)
_vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
self.vocab = _vocab + list(self._punctuations)
if self.is_unique:
duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
assert (
len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
), f" [!] There are duplicate characters in the character set. {duplicates}"
def char_to_id(self, char: str) -> int:
try:
return self._char_to_id[char]
except KeyError as e:
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
def id_to_char(self, idx: int) -> str:
return self._id_to_char[idx]
def print_log(self, level: int = 0):
"""
Prints the vocabulary in a nice format.
"""
indent = "\t" * level
print(f"{indent}| > Characters: {self._characters}")
print(f"{indent}| > Punctuations: {self._punctuations}")
print(f"{indent}| > Pad: {self._pad}")
print(f"{indent}| > EOS: {self._eos}")
print(f"{indent}| > BOS: {self._bos}")
print(f"{indent}| > Blank: {self._blank}")
print(f"{indent}| > Vocab: {self.vocab}")
print(f"{indent}| > Num chars: {self.num_chars}")
@staticmethod
def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument
"""Init your character class from a config.
Implement this method for your subclass.
"""
# use character set from config
if config.characters is not None:
return BaseCharacters(**config.characters), config
# return default character set
characters = BaseCharacters()
new_config = replace(config, characters=characters.to_config())
return characters, new_config
def to_config(self) -> "CharactersConfig":
return CharactersConfig(
characters=self._characters,
punctuations=self._punctuations,
pad=self._pad,
eos=self._eos,
bos=self._bos,
blank=self._blank,
is_unique=self.is_unique,
is_sorted=self.is_sorted,
)
class IPAPhonemes(BaseCharacters):
"""🐸IPAPhonemes class to manage `TTS.tts` model vocabulary
Intended to be used with models using IPAPhonemes as input.
It uses system defaults for the undefined class arguments.
Args:
characters (str):
Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`.
punctuations (str):
Characters to be treated as punctuation. Defaults to `_punctuations`.
pad (str):
Special padding character that would be ignored by the model. Defaults to `_pad`.
eos (str):
End of the sentence character. Defaults to `_eos`.
bos (str):
Beginning of the sentence character. Defaults to `_bos`.
blank (str):
Optional character used between characters by some models for better prosody. Defaults to `_blank`.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
"""
def __init__(
self,
characters: str = _phonemes,
punctuations: str = _punctuations,
pad: str = _pad,
eos: str = _eos,
bos: str = _bos,
blank: str = _blank,
is_unique: bool = False,
is_sorted: bool = True,
) -> None:
super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
@staticmethod
def init_from_config(config: "Coqpit"):
"""Init a IPAPhonemes object from a model config
If characters are not defined in the config, it will be set to the default characters and the config
will be updated.
"""
# band-aid for compatibility with old models
if "characters" in config and config.characters is not None:
if "phonemes" in config.characters and config.characters.phonemes is not None:
config.characters["characters"] = config.characters["phonemes"]
return (
IPAPhonemes(
characters=config.characters["characters"],
punctuations=config.characters["punctuations"],
pad=config.characters["pad"],
eos=config.characters["eos"],
bos=config.characters["bos"],
blank=config.characters["blank"],
is_unique=config.characters["is_unique"],
is_sorted=config.characters["is_sorted"],
),
config,
)
# use character set from config
if config.characters is not None:
return IPAPhonemes(**config.characters), config
# return default character set
characters = IPAPhonemes()
new_config = replace(config, characters=characters.to_config())
return characters, new_config
class Graphemes(BaseCharacters):
"""🐸Graphemes class to manage `TTS.tts` model vocabulary
Intended to be used with models using graphemes as input.
It uses system defaults for the undefined class arguments.
Args:
characters (str):
Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`.
punctuations (str):
Characters to be treated as punctuation. Defaults to `_punctuations`.
pad (str):
Special padding character that would be ignored by the model. Defaults to `_pad`.
eos (str):
End of the sentence character. Defaults to `_eos`.
bos (str):
Beginning of the sentence character. Defaults to `_bos`.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
is_sorted (bool):
Sort the characters in alphabetical order. Defaults to True.
"""
def __init__(
self,
characters: str = _characters,
punctuations: str = _punctuations,
pad: str = _pad,
eos: str = _eos,
bos: str = _bos,
blank: str = _blank,
is_unique: bool = False,
is_sorted: bool = True,
) -> None:
super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
@staticmethod
def init_from_config(config: "Coqpit"):
"""Init a Graphemes object from a model config
If characters are not defined in the config, it will be set to the default characters and the config
will be updated.
"""
if config.characters is not None:
# band-aid for compatibility with old models
if "phonemes" in config.characters:
return (
Graphemes(
characters=config.characters["characters"],
punctuations=config.characters["punctuations"],
pad=config.characters["pad"],
eos=config.characters["eos"],
bos=config.characters["bos"],
blank=config.characters["blank"],
is_unique=config.characters["is_unique"],
is_sorted=config.characters["is_sorted"],
),
config,
)
return Graphemes(**config.characters), config
characters = Graphemes()
new_config = replace(config, characters=characters.to_config())
return characters, new_config
if __name__ == "__main__":
gr = Graphemes()
ph = IPAPhonemes()
gr.print_log()
ph.print_log()

View File

@ -19,7 +19,7 @@ def _chinese_pinyin_to_phoneme(pinyin: str) -> str:
return phoneme + tone
def chinese_text_to_phonemes(text: str) -> str:
def chinese_text_to_phonemes(text: str, seperator: str = "|") -> str:
tokenized_text = jieba.cut(text, HMM=False)
tokenized_text = " ".join(tokenized_text)
pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text)
@ -34,4 +34,4 @@ def chinese_text_to_phonemes(text: str) -> str:
else: # is ponctuation or other
results += list(token)
return "|".join(results)
return seperator.join(results)

View File

@ -1,12 +1,16 @@
"""Set of default text cleaners"""
# TODO: pick the cleaner for languages dynamically
import re
from anyascii import anyascii
from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text
from .abbreviations import abbreviations_en, abbreviations_fr
from .number_norm import normalize_numbers
from .time import expand_time_english
from .english.abbreviations import abbreviations_en
from .english.number_norm import normalize_numbers as en_normalize_numbers
from .english.time_norm import expand_time_english
from .french.abbreviations import abbreviations_fr
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
@ -22,10 +26,6 @@ def expand_abbreviations(text, lang="en"):
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
@ -92,7 +92,17 @@ def english_cleaners(text):
# text = convert_to_ascii(text)
text = lowercase(text)
text = expand_time_english(text)
text = expand_numbers(text)
text = en_normalize_numbers(text)
text = expand_abbreviations(text)
text = replace_symbols(text)
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text
def phoneme_cleaners(text):
"""Pipeline for phonemes mode, including number and abbreviation expansion."""
text = en_normalize_numbers(text)
text = expand_abbreviations(text)
text = replace_symbols(text)
text = remove_aux_symbols(text)
@ -126,17 +136,6 @@ def chinese_mandarin_cleaners(text: str) -> str:
return text
def phoneme_cleaners(text):
"""Pipeline for phonemes mode, including number and abbreviation expansion."""
text = expand_numbers(text)
# text = convert_to_ascii(text)
text = expand_abbreviations(text)
text = replace_symbols(text)
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text
def multilingual_cleaners(text):
"""Pipeline for multilingual text"""
text = lowercase(text)

View File

@ -0,0 +1,26 @@
import re
# List of (regular expression, replacement) pairs for abbreviations in english:
abbreviations_en = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]

View File

@ -1,30 +1,5 @@
import re
# List of (regular expression, replacement) pairs for abbreviations in english:
abbreviations_en = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
# List of (regular expression, replacement) pairs for abbreviations in french:
abbreviations_fr = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])

View File

@ -0,0 +1,57 @@
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer
PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, JA_JP_Phonemizer)}
ESPEAK_LANGS = list(ESpeak.supported_languages().keys())
GRUUT_LANGS = list(Gruut.supported_languages())
# Dict setting default phonemizers for each language
DEF_LANG_TO_PHONEMIZER = {
"ja-jp": JA_JP_Phonemizer.name(),
"zh-cn": ZH_CN_Phonemizer.name(),
}
# Add Gruut languages
_ = [Gruut.name()] * len(GRUUT_LANGS)
_new_dict = dict(list(zip(GRUUT_LANGS, _)))
DEF_LANG_TO_PHONEMIZER.update(_new_dict)
# Add ESpeak languages and override any existing ones
_ = [ESpeak.name()] * len(ESPEAK_LANGS)
_new_dict = dict(list(zip(list(ESPEAK_LANGS), _)))
DEF_LANG_TO_PHONEMIZER.update(_new_dict)
DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"]
def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer:
"""Initiate a phonemizer by name
Args:
name (str):
Name of the phonemizer that should match `phonemizer.name()`.
kwargs (dict):
Extra keyword arguments that should be passed to the phonemizer.
"""
if name == "espeak":
return ESpeak(**kwargs)
if name == "gruut":
return Gruut(**kwargs)
if name == "zh_cn_phonemizer":
return ZH_CN_Phonemizer(**kwargs)
if name == "ja_jp_phonemizer":
return JA_JP_Phonemizer(**kwargs)
raise ValueError(f"Phonemizer {name} not found")
if __name__ == "__main__":
print(DEF_LANG_TO_PHONEMIZER)

View File

@ -0,0 +1,141 @@
import abc
from typing import List, Tuple
from TTS.tts.utils.text.punctuation import Punctuation
class BasePhonemizer(abc.ABC):
"""Base phonemizer class
Phonemization follows the following steps:
1. Preprocessing:
- remove empty lines
- remove punctuation
- keep track of punctuation marks
2. Phonemization:
- convert text to phonemes
3. Postprocessing:
- join phonemes
- restore punctuation marks
Args:
language (str):
Language used by the phonemizer.
punctuations (List[str]):
List of punctuation marks to be preserved.
keep_puncs (bool):
Whether to preserve punctuation marks or not.
"""
def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
# ensure the backend is installed on the system
if not self.is_available():
raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
# ensure the backend support the requested language
self._language = self._init_language(language)
# setup punctuation processing
self._keep_puncs = keep_puncs
self._punctuator = Punctuation(punctuations)
def _init_language(self, language):
"""Language initialization
This method may be overloaded in child classes (see Segments backend)
"""
if not self.is_supported_language(language):
raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
return language
@property
def language(self):
"""The language code configured to be used for phonemization"""
return self._language
@staticmethod
@abc.abstractmethod
def name():
"""The name of the backend"""
...
@classmethod
@abc.abstractmethod
def is_available(cls):
"""Returns True if the backend is installed, False otherwise"""
...
@classmethod
@abc.abstractmethod
def version(cls):
"""Return the backend version as a tuple (major, minor, patch)"""
...
@staticmethod
@abc.abstractmethod
def supported_languages():
"""Return a dict of language codes -> name supported by the backend"""
...
def is_supported_language(self, language):
"""Returns True if `language` is supported by the backend"""
return language in self.supported_languages()
@abc.abstractmethod
def _phonemize(self, text, separator):
"""The main phonemization method"""
def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
"""Preprocess the text before phonemization
1. remove spaces
2. remove punctuation
Override this if you need a different behaviour
"""
text = text.strip()
if self._keep_puncs:
# a tuple (text, punctuation marks)
return self._punctuator.strip_to_restore(text)
return [self._punctuator.strip(text)], []
def _phonemize_postprocess(self, phonemized, punctuations) -> str:
"""Postprocess the raw phonemized output
Override this if you need a different behaviour
"""
if self._keep_puncs:
return self._punctuator.restore(phonemized, punctuations)[0]
return phonemized[0]
def phonemize(self, text: str, separator="|") -> str:
"""Returns the `text` phonemized for the given language
Args:
text (str):
Text to be phonemized.
separator (str):
string separator used between phonemes. Default to '_'.
Returns:
(str): Phonemized text
"""
text, punctuations = self._phonemize_preprocess(text)
phonemized = []
for t in text:
p = self._phonemize(t, separator)
phonemized.append(p)
phonemized = self._phonemize_postprocess(phonemized, punctuations)
return phonemized
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.language}")
print(f"{indent}| > phoneme backend: {self.name()}")

View File

@ -0,0 +1,225 @@
import logging
import subprocess
from typing import Dict, List
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation
def is_tool(name):
from shutil import which
return which(name) is not None
# priority: espeakng > espeak
if is_tool("espeak-ng"):
_DEF_ESPEAK_LIB = "espeak-ng"
elif is_tool("espeak"):
_DEF_ESPEAK_LIB = "espeak"
else:
_DEF_ESPEAK_LIB = None
def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]:
"""Run espeak with the given arguments."""
cmd = [
espeak_lib,
"-q",
"-b",
"1", # UTF8 text encoding
]
cmd.extend(args)
logging.debug("espeakng: executing %s", repr(cmd))
with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as p:
res = iter(p.stdout.readline, b"")
if not sync:
p.stdout.close()
if p.stderr:
p.stderr.close()
if p.stdin:
p.stdin.close()
return res
res2 = []
for line in res:
res2.append(line)
p.stdout.close()
if p.stderr:
p.stderr.close()
if p.stdin:
p.stdin.close()
p.wait()
return res2
class ESpeak(BasePhonemizer):
"""ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line the perform G2P
Args:
language (str):
Valid language code for the used backend.
backend (str):
Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically
prefering `espeak-ng` over `espeak`. Defaults to None.
punctuations (str):
Characters to be treated as punctuation. Defaults to Punctuation.default_puncs().
keep_puncs (bool):
If True, keep the punctuations after phonemization. Defaults to True.
Example:
>>> from TTS.tts.utils.text.phonemizers import ESpeak
>>> phonemizer = ESpeak("tr")
>>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|")
'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.'
"""
_ESPEAK_LIB = _DEF_ESPEAK_LIB
def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True):
if self._ESPEAK_LIB is None:
raise Exception(" [!] No espeak backend found. Install espeak-ng or espeak to your system.")
self.backend = self._ESPEAK_LIB
# band-aid for backwards compatibility
if language == "en":
language = "en-us"
super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
if backend is not None:
self.backend = backend
@property
def backend(self):
return self._ESPEAK_LIB
@backend.setter
def backend(self, backend):
if backend not in ["espeak", "espeak-ng"]:
raise Exception("Unknown backend: %s" % backend)
self._ESPEAK_LIB = backend
# skip first two characters of the retuned text
# "_ p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
# ^^
self.num_skip_chars = 2
if backend == "espeak-ng":
# skip the first character of the retuned text
# "_p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
# ^
self.num_skip_chars = 1
def auto_set_espeak_lib(self) -> None:
if is_tool("espeak-ng"):
self._ESPEAK_LIB = "espeak-ng"
elif is_tool("espeak"):
self._ESPEAK_LIB = "espeak"
else:
raise Exception("Cannot set backend automatically. espeak-ng or espeak not found")
@staticmethod
def name():
return "espeak"
def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str:
"""Convert input text to phonemes.
Args:
text (str):
Text to be converted to phonemes.
tie (bool, optional) : When True use a '͡' character between
consecutive characters of a single phoneme. Else separate phoneme
with '_'. This option requires espeak>=1.49. Default to False.
"""
# set arguments
args = ["-v", f"{self._language}"]
# espeak and espeak-ng parses `ipa` differently
if tie:
# use '͡' between phonemes
if self.backend == "espeak":
args.append("--ipa=1")
else:
args.append("--ipa=3")
else:
# split with '_'
if self.backend == "espeak":
args.append("--ipa=3")
else:
args.append("--ipa=1")
if tie:
args.append("--tie=%s" % tie)
args.append('"' + text + '"')
# compute phonemes
phonemes = ""
for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
logging.debug("line: %s", repr(line))
phonemes += line.decode("utf8").strip()[self.num_skip_chars :] # skip initial redundant characters
return phonemes.replace("_", separator)
def _phonemize(self, text, separator=None):
return self.phonemize_espeak(text, separator, tie=False)
@staticmethod
def supported_languages() -> Dict:
"""Get a dictionary of supported languages.
Returns:
Dict: Dictionary of language codes.
"""
if _DEF_ESPEAK_LIB is None:
return {}
args = ["--voices"]
langs = {}
count = 0
for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True):
line = line.decode("utf8").strip()
if count > 0:
cols = line.split()
lang_code = cols[1]
lang_name = cols[3]
langs[lang_code] = lang_name
logging.debug("line: %s", repr(line))
count += 1
return langs
def version(self) -> str:
"""Get the version of the used backend.
Returns:
str: Version of the used backend.
"""
args = ["--version"]
for line in _espeak_exe(self.backend, args, sync=True):
version = line.decode("utf8").strip().split()[2]
logging.debug("line: %s", repr(line))
return version
@classmethod
def is_available(cls):
"""Return true if ESpeak is available else false"""
return is_tool("espeak") or is_tool("espeak-ng")
if __name__ == "__main__":
e = ESpeak(language="en-us")
print(e.supported_languages())
print(e.version())
print(e.language)
print(e.name())
print(e.is_available())
e = ESpeak(language="en-us", keep_puncs=False)
print("`" + e.phonemize("hello how are you today?") + "`")
e = ESpeak(language="en-us", keep_puncs=True)
print("`" + e.phonemize("hello how are you today?") + "`")

View File

@ -0,0 +1,151 @@
import importlib
from typing import List
import gruut
from gruut_ipa import IPA
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation
# Table for str.translate to fix gruut/TTS phoneme mismatch
GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
class Gruut(BasePhonemizer):
"""Gruut wrapper for G2P
Args:
language (str):
Valid language code for the used backend.
punctuations (str):
Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.
keep_puncs (bool):
If true, keep the punctuations after phonemization. Defaults to True.
use_espeak_phonemes (bool):
If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False.
keep_stress (bool):
If true, keep the stress characters after phonemization. Defaults to False.
Example:
>>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
>>> phonemizer = Gruut('en-us')
>>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
"""
def __init__(
self,
language: str,
punctuations=Punctuation.default_puncs(),
keep_puncs=True,
use_espeak_phonemes=False,
keep_stress=False,
):
super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
self.use_espeak_phonemes = use_espeak_phonemes
self.keep_stress = keep_stress
@staticmethod
def name():
return "gruut"
def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument
"""Convert input text to phonemes.
Gruut phonemizes the given `str` by seperating each phoneme character with `separator`, even for characters
that constitude a single sound.
It doesn't affect 🐸TTS since it individually converts each character to token IDs.
Examples::
"hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`
Args:
text (str):
Text to be converted to phonemes.
tie (bool, optional) : When True use a '͡' character between
consecutive characters of a single phoneme. Else separate phoneme
with '_'. This option requires espeak>=1.49. Default to False.
"""
ph_list = []
for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
for word in sentence:
if word.is_break:
# Use actual character for break phoneme (e.g., comma)
if ph_list:
# Join with previous word
ph_list[-1].append(word.text)
else:
# First word is punctuation
ph_list.append([word.text])
elif word.phonemes:
# Add phonemes for word
word_phonemes = []
for word_phoneme in word.phonemes:
if not self.keep_stress:
# Remove primary/secondary stress
word_phoneme = IPA.without_stress(word_phoneme)
word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
if word_phoneme:
# Flatten phonemes
word_phonemes.extend(word_phoneme)
if word_phonemes:
ph_list.append(word_phonemes)
ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
ph = f"{separator} ".join(ph_words)
return ph
def _phonemize(self, text, separator):
return self.phonemize_gruut(text, separator, tie=False)
def is_supported_language(self, language):
"""Returns True if `language` is supported by the backend"""
return gruut.is_language_supported(language)
@staticmethod
def supported_languages() -> List:
"""Get a dictionary of supported languages.
Returns:
List: List of language codes.
"""
return list(gruut.get_supported_languages())
def version(self):
"""Get the version of the used backend.
Returns:
str: Version of the used backend.
"""
return gruut.__version__
@classmethod
def is_available(cls):
"""Return true if ESpeak is available else false"""
return importlib.util.find_spec("gruut") is not None
if __name__ == "__main__":
e = Gruut(language="en-us")
print(e.supported_languages())
print(e.version())
print(e.language)
print(e.name())
print(e.is_available())
e = Gruut(language="en-us", keep_puncs=False)
print("`" + e.phonemize("hello how are you today?") + "`")
e = Gruut(language="en-us", keep_puncs=True)
print("`" + e.phonemize("hello how, are you today?") + "`")

View File

@ -0,0 +1,72 @@
from typing import Dict
from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
_DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】"
_TRANS_TABLE = {"": ","}
def trans(text):
for i, j in _TRANS_TABLE.items():
text = text.replace(i, j)
return text
class JA_JP_Phonemizer(BasePhonemizer):
"""🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer`
TODO: someone with JA knowledge should check this implementation
Example:
>>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer
>>> phonemizer = JA_JP_Phonemizer()
>>> phonemizer.phonemize("どちらに行きますか?", separator="|")
'd|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?'
"""
language = "ja-jp"
def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument
super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)
@staticmethod
def name():
return "ja_jp_phonemizer"
def _phonemize(self, text: str, separator: str = "|") -> str:
ph = japanese_text_to_phonemes(text)
if separator is not None or separator != "":
return separator.join(ph)
return ph
def phonemize(self, text: str, separator="|") -> str:
"""Custom phonemize for JP_JA
Skip pre-post processing steps used by the other phonemizers.
"""
return self._phonemize(text, separator)
@staticmethod
def supported_languages() -> Dict:
return {"ja-jp": "Japanese (Japan)"}
def version(self) -> str:
return "0.0.1"
def is_available(self) -> bool:
return True
# if __name__ == "__main__":
# text = "これは、電話をかけるための私の日本語の例のテキストです。"
# e = JA_JP_Phonemizer()
# print(e.supported_languages())
# print(e.version())
# print(e.language)
# print(e.name())
# print(e.is_available())
# print("`" + e.phonemize(text) + "`")

View File

@ -0,0 +1,55 @@
from typing import Dict, List
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
class MultiPhonemizer:
"""🐸TTS multi-phonemizer that operates phonemizers for multiple langugages
Args:
custom_lang_to_phonemizer (Dict):
Custom phonemizer mapping if you want to change the defaults. In the format of
`{"lang_code", "phonemizer_name"}`. When it is None, `DEF_LANG_TO_PHONEMIZER` is used. Defaults to `{}`.
TODO: find a way to pass custom kwargs to the phonemizers
"""
lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER
language = "multi-lingual"
def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: # pylint: disable=dangerous-default-value
self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer)
self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name)
@staticmethod
def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict:
lang_to_phonemizer = {}
for k, v in lang_to_phonemizer_name.items():
phonemizer = get_phonemizer_by_name(v, language=k)
lang_to_phonemizer[k] = phonemizer
return lang_to_phonemizer
@staticmethod
def name():
return "multi-phonemizer"
def phonemize(self, text, language, separator="|"):
return self.lang_to_phonemizer[language].phonemize(text, separator)
def supported_languages(self) -> List:
return list(self.lang_to_phonemizer_name.keys())
# if __name__ == "__main__":
# texts = {
# "tr": "Merhaba, bu Türkçe bit örnek!",
# "en-us": "Hello, this is English example!",
# "de": "Hallo, das ist ein Deutches Beipiel!",
# "zh-cn": "这是中国的例子",
# }
# phonemes = {}
# ph = MultiPhonemizer()
# for lang, text in texts.items():
# phoneme = ph.phonemize(text, lang)
# phonemes[lang] = phoneme
# print(phonemes)

View File

@ -0,0 +1,62 @@
from typing import Dict
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】"
class ZH_CN_Phonemizer(BasePhonemizer):
"""🐸TTS Zh-Cn phonemizer using functions in `TTS.tts.utils.text.chinese_mandarin.phonemizer`
Args:
punctuations (str):
Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`.
keep_puncs (bool):
If True, keep the punctuations after phonemization. Defaults to False.
Example ::
"这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| || |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |`
TODO: someone with Mandarin knowledge should check this implementation
"""
language = "zh-cn"
def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument
super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)
@staticmethod
def name():
return "zh_cn_phonemizer"
@staticmethod
def phonemize_zh_cn(text: str, separator: str = "|") -> str:
ph = chinese_text_to_phonemes(text, separator)
return ph
def _phonemize(self, text, separator):
return self.phonemize_zh_cn(text, separator)
@staticmethod
def supported_languages() -> Dict:
return {"zh-cn": "Japanese (Japan)"}
def version(self) -> str:
return "0.0.1"
def is_available(self) -> bool:
return True
# if __name__ == "__main__":
# text = "这是,样本中文。"
# e = ZH_CN_Phonemizer()
# print(e.supported_languages())
# print(e.version())
# print(e.language)
# print(e.name())
# print(e.is_available())
# print("`" + e.phonemize(text) + "`")

View File

@ -0,0 +1,172 @@
import collections
import re
from enum import Enum
import six
_DEF_PUNCS = ';:,.!?¡¿—…"«»“”'
_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])
class PuncPosition(Enum):
"""Enum for the punctuations positions"""
BEGIN = 0
END = 1
MIDDLE = 2
ALONE = 3
class Punctuation:
"""Handle punctuations in text.
Just strip punctuations from text or strip and restore them later.
Args:
puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
Example:
>>> punc = Punctuation()
>>> punc.strip("This is. example !")
'This is example'
>>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
>>> ' '.join(text_striped)
'This is example'
>>> text_restored = punc.restore(text_striped, punc_map)
>>> text_restored[0]
'This is. example !'
"""
def __init__(self, puncs: str = _DEF_PUNCS):
self.puncs = puncs
@staticmethod
def default_puncs():
"""Return default set of punctuations."""
return _DEF_PUNCS
@property
def puncs(self):
return self._puncs
@puncs.setter
def puncs(self, value):
if not isinstance(value, six.string_types):
raise ValueError("[!] Punctuations must be of type str.")
self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder
self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
def strip(self, text):
"""Remove all the punctuations by replacing with `space`.
Args:
text (str): The text to be processed.
Example::
"This is. example !" -> "This is example "
"""
return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()
def strip_to_restore(self, text):
"""Remove punctuations from text to restore them later.
Args:
text (str): The text to be processed.
Examples ::
"This is. example !" -> [["This is", "example"], [".", "!"]]
"""
text, puncs = self._strip_to_restore(text)
return text, puncs
def _strip_to_restore(self, text):
"""Auxiliary method for Punctuation.preserve()"""
matches = list(re.finditer(self.puncs_regular_exp, text))
if not matches:
return [text], []
# the text is only punctuations
if len(matches) == 1 and matches[0].group() == text:
return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
# build a punctuation map to be used later to restore punctuations
puncs = []
for match in matches:
position = PuncPosition.MIDDLE
if match == matches[0] and text.startswith(match.group()):
position = PuncPosition.BEGIN
elif match == matches[-1] and text.endswith(match.group()):
position = PuncPosition.END
puncs.append(_PUNC_IDX(match.group(), position))
# convert str text to a List[str], each item is separated by a punctuation
splitted_text = []
for idx, punc in enumerate(puncs):
split = text.split(punc.punc)
prefix, suffix = split[0], punc.punc.join(split[1:])
splitted_text.append(prefix)
# if the text does not end with a punctuation, add it to the last item
if idx == len(puncs) - 1 and len(suffix) > 0:
splitted_text.append(suffix)
text = suffix
return splitted_text, puncs
@classmethod
def restore(cls, text, puncs):
"""Restore punctuation in a text.
Args:
text (str): The text to be processed.
puncs (List[str]): The list of punctuations map to be used for restoring.
Examples ::
['This is', 'example'], ['.', '!'] -> "This is. example!"
"""
return cls._restore(text, puncs, 0)
@classmethod
def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
"""Auxiliary method for Punctuation.restore()"""
if not puncs:
return text
# nothing have been phonemized, returns the puncs alone
if not text:
return ["".join(m.mark for m in puncs)]
current = puncs[0]
if current.position == PuncPosition.BEGIN:
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
if current.position == PuncPosition.END:
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
if current.position == PuncPosition.ALONE:
return [current.mark] + cls._restore(text, puncs[1:], num + 1)
# POSITION == MIDDLE
if len(text) == 1: # pragma: nocover
# a corner case where the final part of an intermediate
# mark (I) has not been phonemized
return cls._restore([text[0] + current.punc], puncs[1:], num)
return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
# if __name__ == "__main__":
# punc = Punctuation()
# text = "This is. This is, example!"
# print(punc.strip(text))
# split_text, puncs = punc.strip_to_restore(text)
# print(split_text, " ---- ", puncs)
# restored_text = punc.restore(split_text, puncs)
# print(restored_text)

View File

@ -1,75 +0,0 @@
# -*- coding: utf-8 -*-
"""
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
"""
def make_symbols(
characters,
phonemes=None,
punctuations="!'(),-.:;? ",
pad="_",
eos="~",
bos="^",
unique=True,
): # pylint: disable=redefined-outer-name
"""Function to create symbols and phonemes
TODO: create phonemes_to_id and symbols_to_id dicts here."""
_symbols = list(characters)
_symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols
_symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols
_symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols
_phonemes = None
if phonemes is not None:
_phonemes_sorted = (
sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
) # this is to keep previous models compatible.
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ["@" + s for s in _phonemes_sorted]
# Export all symbols:
_phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
# _symbols += _arpabet
return _symbols, _phonemes
_pad = "_"
_eos = "~"
_bos = "^"
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? "
_punctuations = "!'(),-.:;? "
# Phonemes definition (All IPA characters)
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "ˈˌːˑ"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)
# Generate ALIEN language
# from random import shuffle
# shuffle(phonemes)
def parse_symbols():
return {
"pad": _pad,
"eos": _eos,
"bos": _bos,
"characters": _characters,
"punctuations": _punctuations,
"phonemes": _phonemes,
}
if __name__ == "__main__":
print(" > TTS symbols {}".format(len(symbols)))
print(symbols)
print(" > TTS phonemes {}".format(len(phonemes)))
print("".join(sorted(phonemes)))

View File

@ -0,0 +1,205 @@
from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.utils.generic_utils import get_import_path, import_class
class TTSTokenizer:
"""🐸TTS tokenizer to convert input characters to token IDs and back.
Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later.
Args:
use_phonemes (bool):
Whether to use phonemes instead of characters. Defaults to False.
characters (Characters):
A Characters object to use for character-to-ID and ID-to-character mappings.
text_cleaner (callable):
A function to pre-process the text before tokenization and phonemization. Defaults to None.
phonemizer (Phonemizer):
A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None.
Example:
>>> from TTS.tts.utils.text.tokenizer import TTSTokenizer
>>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes())
>>> text = "Hello world!"
>>> ids = tokenizer.text_to_ids(text)
>>> text_hat = tokenizer.ids_to_text(ids)
>>> assert text == text_hat
"""
def __init__(
self,
use_phonemes=False,
text_cleaner: Callable = None,
characters: "BaseCharacters" = None,
phonemizer: Union["Phonemizer", Dict] = None,
add_blank: bool = False,
use_eos_bos=False,
):
self.text_cleaner = text_cleaner
self.use_phonemes = use_phonemes
self.add_blank = add_blank
self.use_eos_bos = use_eos_bos
self.characters = characters
self.not_found_characters = []
self.phonemizer = phonemizer
@property
def characters(self):
return self._characters
@characters.setter
def characters(self, new_characters):
self._characters = new_characters
self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None
self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None
def encode(self, text: str) -> List[int]:
"""Encodes a string of text as a sequence of IDs."""
token_ids = []
for char in text:
try:
idx = self.characters.char_to_id(char)
token_ids.append(idx)
except KeyError:
# discard but store not found characters
if char not in self.not_found_characters:
self.not_found_characters.append(char)
print(text)
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
return token_ids
def decode(self, token_ids: List[int]) -> str:
"""Decodes a sequence of IDs to a string of text."""
text = ""
for token_id in token_ids:
text += self.characters.id_to_char(token_id)
return text
def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument
"""Converts a string of text to a sequence of token IDs.
Args:
text(str):
The text to convert to token IDs.
language(str):
The language code of the text. Defaults to None.
TODO:
- Add support for language-specific processing.
1. Text normalizatin
2. Phonemization (if use_phonemes is True)
3. Add blank char between characters
4. Add BOS and EOS characters
5. Text to token IDs
"""
# TODO: text cleaner should pick the right routine based on the language
if self.text_cleaner is not None:
text = self.text_cleaner(text)
if self.use_phonemes:
text = self.phonemizer.phonemize(text, separator="")
if self.add_blank:
text = self.intersperse_blank_char(text, True)
if self.use_eos_bos:
text = self.pad_with_bos_eos(text)
return self.encode(text)
def ids_to_text(self, id_sequence: List[int]) -> str:
"""Converts a sequence of token IDs to a string of text."""
return self.decode(id_sequence)
def pad_with_bos_eos(self, char_sequence: List[str]):
"""Pads a sequence with the special BOS and EOS characters."""
return [self.characters.bos] + list(char_sequence) + [self.characters.eos]
def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
"""Intersperses the blank character between characters in a sequence.
Use the ```blank``` character if defined else use the ```pad``` character.
"""
char_to_use = self.characters.blank if use_blank_char else self.characters.pad
result = [char_to_use] * (len(char_sequence) * 2 + 1)
result[1::2] = char_sequence
return result
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > add_blank: {self.add_blank}")
print(f"{indent}| > use_eos_bos: {self.use_eos_bos}")
print(f"{indent}| > use_phonemes: {self.use_phonemes}")
if self.use_phonemes:
print(f"{indent}| > phonemizer:")
self.phonemizer.print_logs(level + 1)
if len(self.not_found_characters) > 0:
print(f"{indent}| > {len(self.not_found_characters)} not found characters:")
for char in self.not_found_characters:
print(f"{indent}| > {char}")
@staticmethod
def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):
"""Init Tokenizer object from config
Args:
config (Coqpit): Coqpit model config.
characters (BaseCharacters): Defines the model character set. If not set, use the default options based on
the config values. Defaults to None.
"""
# init cleaners
text_cleaner = None
if isinstance(config.text_cleaner, (str, list)):
text_cleaner = getattr(cleaners, config.text_cleaner)
# init characters
if characters is None:
# set characters based on defined characters class
if config.characters and config.characters.characters_class:
CharactersClass = import_class(config.characters.characters_class)
characters, new_config = CharactersClass.init_from_config(config)
# set characters based on config
else:
if config.use_phonemes:
# init phoneme set
characters, new_config = IPAPhonemes().init_from_config(config)
else:
# init character set
characters, new_config = Graphemes().init_from_config(config)
else:
characters, new_config = characters.init_from_config(config)
# set characters class
new_config.characters.characters_class = get_import_path(characters)
# init phonemizer
phonemizer = None
if config.use_phonemes:
phonemizer_kwargs = {"language": config.phoneme_language}
if "phonemizer" in config and config.phonemizer:
phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
else:
try:
phonemizer = get_phonemizer_by_name(
DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
)
except KeyError as e:
raise ValueError(
f"""No phonemizer found for language {config.phoneme_language}.
You may need to install a third party library for this language."""
) from e
return (
TTSTokenizer(
config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars
),
new_config,
)

View File

@ -4,8 +4,6 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme
matplotlib.use("Agg")
@ -89,12 +87,46 @@ def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False)
return fig
def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
"""Plot pitch curves on top of the input characters.
Args:
pitch (np.array): Pitch values.
chars (str): Characters to place to the x-axis.
Shapes:
pitch: :math:`(T,)`
"""
old_fig_size = plt.rcParams["figure.figsize"]
if fig_size is not None:
plt.rcParams["figure.figsize"] = fig_size
fig, ax = plt.subplots()
x = np.array(range(len(chars)))
my_xticks = chars
plt.xticks(x, my_xticks)
ax.set_xlabel("characters")
ax.set_ylabel("freq")
ax2 = ax.twinx()
ax2.plot(pitch, linewidth=5.0, color="red")
ax2.set_ylabel("F0")
plt.rcParams["figure.figsize"] = old_fig_size
if not output_fig:
plt.close()
return fig
def visualize(
alignment,
postnet_output,
text,
hop_length,
CONFIG,
tokenizer,
stop_tokens=None,
decoder_output=None,
output_path=None,
@ -117,14 +149,8 @@ def visualize(
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
# compute phoneme representation and back
if CONFIG.use_phonemes:
seq = phoneme_to_sequence(
text,
[CONFIG.text_cleaner],
CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
)
text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None)
seq = tokenizer.text_to_ids(text)
text = tokenizer.ids_to_text(seq)
print(text)
plt.yticks(range(len(text)), list(text))
plt.colorbar()

View File

@ -142,10 +142,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.power is not None:
S = S ** self.power
S = S**self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
@ -239,6 +239,12 @@ class AudioProcessor(object):
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
pitch_fmin (int, optional):
minimum filter frequency for computing pitch. Defaults to None.
pitch_fmax (int, optional):
maximum filter frequency for computing pitch. Defaults to None.
spec_gain (int, optional):
gain applied when converting amplitude to DB. Defaults to 20.
@ -300,6 +306,8 @@ class AudioProcessor(object):
max_norm=None,
mel_fmin=None,
mel_fmax=None,
pitch_fmax=None,
pitch_fmin=None,
spec_gain=20,
stft_pad_mode="reflect",
clip_norm=True,
@ -333,6 +341,8 @@ class AudioProcessor(object):
self.symmetric_norm = symmetric_norm
self.mel_fmin = mel_fmin or 0
self.mel_fmax = mel_fmax
self.pitch_fmin = pitch_fmin
self.pitch_fmax = pitch_fmax
self.spec_gain = float(spec_gain)
self.stft_pad_mode = stft_pad_mode
self.max_norm = 1.0 if max_norm is None else float(max_norm)
@ -379,6 +389,12 @@ class AudioProcessor(object):
self.clip_norm = None
self.symmetric_norm = None
@staticmethod
def init_from_config(config: "Coqpit", verbose=True):
if "audio" in config:
return AudioProcessor(verbose=verbose, **config.audio)
return AudioProcessor(verbose=verbose, **config)
### setting up the parameters ###
def _build_mel_basis(
self,
@ -634,8 +650,8 @@ class AudioProcessor(object):
S = self._db_to_amp(S)
# Reconstruct phase
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self._griffin_lim(S ** self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
return self._griffin_lim(S**self.power)
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
@ -643,8 +659,8 @@ class AudioProcessor(object):
S = self._db_to_amp(D)
S = self._mel_to_linear(S) # Convert back to linear
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self._griffin_lim(S ** self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
return self._griffin_lim(S**self.power)
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
@ -720,11 +736,12 @@ class AudioProcessor(object):
>>> WAV_FILE = filename = librosa.util.example_audio_file()
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio import AudioProcessor
>>> conf = BaseAudioConfig(mel_fmax=8000)
>>> conf = BaseAudioConfig(pitch_fmax=8000)
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
>>> pitch = ap.compute_f0(wav)
"""
assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
# align F0 length to the spectrogram length
if len(x) % self.hop_length == 0:
x = np.pad(x, (0, self.hop_length // 2), mode="reflect")
@ -732,7 +749,7 @@ class AudioProcessor(object):
f0, t = pw.dio(
x.astype(np.double),
fs=self.sample_rate,
f0_ceil=self.mel_fmax,
f0_ceil=self.pitch_fmax,
frame_period=1000 * self.hop_length / self.sample_rate,
)
f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
@ -781,7 +798,7 @@ class AudioProcessor(object):
@staticmethod
def _rms_norm(wav, db_level=-27):
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
return wav * a
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
@ -853,7 +870,7 @@ class AudioProcessor(object):
@staticmethod
def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
mu = 2 ** qc - 1
mu = 2**qc - 1
# wav_abs = np.minimum(np.abs(wav), 1.0)
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
# Quantize signal to the specified number of levels.
@ -865,13 +882,13 @@ class AudioProcessor(object):
@staticmethod
def mulaw_decode(wav, qc):
"""Recovers waveform from quantized values."""
mu = 2 ** qc - 1
mu = 2**qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
@staticmethod
def encode_16bits(x):
return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
@staticmethod
def quantize(x: np.ndarray, bits: int) -> np.ndarray:
@ -884,12 +901,12 @@ class AudioProcessor(object):
Returns:
np.ndarray: Quantized waveform.
"""
return (x + 1.0) * (2 ** bits - 1) / 2
return (x + 1.0) * (2**bits - 1) / 2
@staticmethod
def dequantize(x, bits):
"""Dequantize a waveform from the given number of bits."""
return 2 * x / (2 ** bits - 1) - 1
return 2 * x / (2**bits - 1) - 1
def _log(x, base):

View File

@ -128,7 +128,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") ->
while True:
# Read by chunk to avoid filling memory
chunk = file_obj.read(1024 ** 2)
chunk = file_obj.read(1024**2)
if not chunk:
break
hash_func.update(chunk)

Some files were not shown because too many files have changed in this diff Show More