Mirror of https://github.com/coqui-ai/TTS.git
commit 209ee40c88
.compute
@@ -1,16 +0,0 @@
-#!/bin/bash
-yes | apt-get install sox
-yes | apt-get install ffmpeg
-yes | apt-get install tmux
-yes | apt-get install zsh
-sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"
-pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl
-sudo sh install.sh
-# pip install pytorch==1.7.0+cu100
-# python3 setup.py develop
-# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
-# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
-# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
-# python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360
-# python3 distribute.py --config_path config.json
-while true; do sleep 1000000; done
@@ -0,0 +1,46 @@
+name: data-tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    types: [opened, synchronize, reopened]
+jobs:
+  check_skip:
+    runs-on: ubuntu-latest
+    if: "! contains(github.event.head_commit.message, '[ci skip]')"
+    steps:
+      - run: echo "${{ github.event.head_commit.message }}"
+
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.6, 3.7, 3.8, 3.9]
+        experimental: [false]
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: coqui-ai/setup-python@pip-cache-key-py-ver
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: x64
+          cache: 'pip'
+          cache-dependency-path: 'requirements*'
+      - name: check OS
+        run: cat /etc/os-release
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y --no-install-recommends git make gcc
+          make system-deps
+      - name: Install/upgrade Python setup deps
+        run: python3 -m pip install --upgrade pip setuptools wheel
+      - name: Install TTS
+        run: |
+          python3 -m pip install .[all]
+          python3 setup.py egg_info
+      - name: Unit tests
+        run: make data_tests
@@ -0,0 +1,46 @@
+name: inference_tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    types: [opened, synchronize, reopened]
+jobs:
+  check_skip:
+    runs-on: ubuntu-latest
+    if: "! contains(github.event.head_commit.message, '[ci skip]')"
+    steps:
+      - run: echo "${{ github.event.head_commit.message }}"
+
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.6, 3.7, 3.8, 3.9]
+        experimental: [false]
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: coqui-ai/setup-python@pip-cache-key-py-ver
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: x64
+          cache: 'pip'
+          cache-dependency-path: 'requirements*'
+      - name: check OS
+        run: cat /etc/os-release
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y --no-install-recommends git make gcc
+          make system-deps
+      - name: Install/upgrade Python setup deps
+        run: python3 -m pip install --upgrade pip setuptools wheel
+      - name: Install TTS
+        run: |
+          python3 -m pip install .[all]
+          python3 setup.py egg_info
+      - name: Unit tests
+        run: make inference_tests
@@ -0,0 +1,48 @@
+name: tts-tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    types: [opened, synchronize, reopened]
+jobs:
+  check_skip:
+    runs-on: ubuntu-latest
+    if: "! contains(github.event.head_commit.message, '[ci skip]')"
+    steps:
+      - run: echo "${{ github.event.head_commit.message }}"
+
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.6, 3.7, 3.8, 3.9]
+        experimental: [false]
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: coqui-ai/setup-python@pip-cache-key-py-ver
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: x64
+          cache: 'pip'
+          cache-dependency-path: 'requirements*'
+      - name: check OS
+        run: cat /etc/os-release
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y --no-install-recommends git make gcc
+          sudo apt-get install espeak
+          sudo apt-get install espeak-ng
+          make system-deps
+      - name: Install/upgrade Python setup deps
+        run: python3 -m pip install --upgrade pip setuptools wheel
+      - name: Install TTS
+        run: |
+          python3 -m pip install .[all]
+          python3 setup.py egg_info
+      - name: Unit tests
+        run: make test_text
@@ -35,6 +35,8 @@ jobs:
         run: |
           sudo apt-get update
           sudo apt-get install -y --no-install-recommends git make gcc
+          sudo apt-get install espeak
+          sudo apt-get install espeak-ng
           make system-deps
       - name: Install/upgrade Python setup deps
         run: python3 -m pip install --upgrade pip setuptools wheel
@@ -35,6 +35,7 @@ jobs:
         run: |
           sudo apt-get update
           sudo apt-get install -y git make gcc
+          sudo apt-get install espeak espeak-ng
           make system-deps
       - name: Install/upgrade Python setup deps
         run: python3 -m pip install --upgrade pip setuptools wheel
@@ -164,4 +164,5 @@ internal/*
 *_pitch.npy
 *_phoneme.npy
 wandb
 depot/*
+coqui_recipes/*
@@ -168,7 +168,8 @@ disable=missing-docstring,
         exception-escape,
         comprehension-escape,
         duplicate-code,
-        not-callable
+        not-callable,
+        import-outside-toplevel

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
Makefile
@@ -26,6 +26,15 @@ test_aux: ## run aux tests.
 test_zoo: ## run zoo tests.
     nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id

+inference_tests: ## run inference tests.
+    nosetests tests.inference_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.inference_tests --nologcapture --with-id
+
+data_tests: ## run data tests.
+    nosetests tests.data_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.data_tests --nologcapture --with-id
+
+test_text: ## run text tests.
+    nosetests tests.text_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.text_tests --nologcapture --with-id
+
 test_failed: ## only run tests failed the last time.
     nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed

@@ -41,7 +50,6 @@ system-deps: ## install linux system deps

 dev-deps: ## install development deps
     pip install -r requirements.dev.txt
-    pip install -r requirements.tf.txt

 doc-deps: ## install docs dependencies
     pip install -r docs/requirements.txt

README.md
@@ -61,8 +61,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
 - Detailed training logs on the terminal and Tensorboard.
 - Support for Multi-speaker TTS.
 - Efficient, flexible, lightweight but feature complete `Trainer API`.
-- Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference.
-- Released and read-to-use models.
+- Released and ready-to-use models.
 - Tools to curate Text2Speech datasets under```dataset_analysis```.
 - Utilities to use and test your models.
 - Modular (but not too much) code base enabling easy implementation of new ideas.

@@ -113,17 +112,11 @@ If you are only interested in [synthesizing speech](https://tts.readthedocs.io/e
 pip install TTS
 ```

-By default, this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra.
-
-```bash
-pip install TTS[tf]
-```
-
 If you plan to code or train models, clone 🐸TTS and install it locally.

 ```bash
 git clone https://github.com/coqui-ai/TTS
-pip install -e .[all,dev,notebooks,tf]  # Select the relevant extras
+pip install -e .[all,dev,notebooks]  # Select the relevant extras
 ```

 If you are on Ubuntu (Debian), you can also run following commands for installation.

@@ -204,12 +197,10 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
 |- train*.py                  (train your target model.)
 |- distribute.py              (train your TTS model using Multiple GPUs.)
 |- compute_statistics.py      (compute dataset statistics for normalization.)
-|- convert*.py                (convert target torch model to TF.)
 |- ...
 |- tts/                       (text to speech models)
     |- layers/                (model layer definitions)
     |- models/                (model definitions)
-    |- tf/                    (Tensorflow 2 utilities and model implementations)
     |- utils/                 (model specific utilities.)
 |- speaker_encoder/           (Speaker Encoder models.)
     |- (same)
@@ -4,7 +4,7 @@
         "multi-dataset":{
             "your_tts":{
                 "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
                 "default_vocoder": null,
                 "commit": "e9a1953e",
                 "license": "CC BY-NC-ND 4.0",
@@ -33,7 +33,7 @@
             },
             "tacotron2-DDC_ph": {
                 "description": "Tacotron2 with Double Decoder Consistency with phonemes.",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--tacotron2-DDC_ph.zip",
                 "default_vocoder": "vocoder_models/en/ljspeech/univnet",
                 "commit": "3900448",
                 "author": "Eren Gölge @erogol",
@@ -71,7 +71,7 @@
             },
             "vits": {
                 "description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--ljspeech--vits.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--ljspeech--vits.zip",
                 "default_vocoder": null,
                 "commit": "3900448",
                 "author": "Eren Gölge @erogol",
@@ -89,18 +89,9 @@
             }
         },
         "vctk": {
-            "sc-glow-tts": {
-                "description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip",
-                "default_vocoder": "vocoder_models/en/vctk/hifigan_v2",
-                "commit": "b531fa69",
-                "author": "Edresson Casanova",
-                "license": "",
-                "contact": ""
-            },
             "vits": {
                 "description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.2.0/tts_models--en--vctk--vits.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--vits.zip",
                 "default_vocoder": null,
                 "commit": "3900448",
                 "author": "Eren @erogol",
@@ -109,7 +100,7 @@
             },
             "fast_pitch":{
                 "description": "FastPitch model trained on VCTK dataseset.",
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--en--vctk--fast_pitch.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--en--vctk--fast_pitch.zip",
                 "default_vocoder": null,
                 "commit": "bdab788d",
                 "author": "Eren @erogol",
@@ -156,7 +147,7 @@
     "uk":{
         "mai": {
             "glow-tts": {
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--uk--mailabs--glow-tts.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--uk--mai--glow-tts.zip",
                "author":"@robinhad",
                 "commit": "bdab788d",
                 "license": "MIT",
@@ -168,7 +159,7 @@
     "zh-CN": {
         "baker": {
             "tacotron2-DDC-GST": {
-                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
                 "commit": "unknown",
                 "author": "@kirianguiller",
                 "default_vocoder": null
@@ -206,6 +197,52 @@
                 "commit": "401fbd89"
             }
         }
     },
+    "tr":{
+        "common-voice": {
+            "glow-tts":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--tr--common-voice--glow-tts.zip",
+                "default_vocoder": "vocoder_models/tr/common-voice/hifigan",
+                "license": "MIT",
+                "description": "Turkish GlowTTS model using an unknown speaker from the Common-Voice dataset.",
+                "author": "Fatih Akademi",
+                "commit": null
+            }
+        }
+    },
+    "it": {
+        "mai_female": {
+            "glow-tts":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--glow-tts.zip",
+                "default_vocoder": null,
+                "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                "author": "@nicolalandro",
+                "commit": null
+            },
+            "vits":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_female--vits.zip",
+                "default_vocoder": null,
+                "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                "author": "@nicolalandro",
+                "commit": null
+            }
+        },
+        "mai_male": {
+            "glow-tts":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--glow-tts.zip",
+                "default_vocoder": null,
+                "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                "author": "@nicolalandro",
+                "commit": null
+            },
+            "vits":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/tts_models--it--mai_male--vits.zip",
+                "default_vocoder": null,
+                "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.",
+                "author": "@nicolalandro",
+                "commit": null
+            }
+        }
+    }
 },
 "vocoder_models": {
@@ -324,6 +361,17 @@
                 "contact": ""
             }
         }
     },
+    "tr":{
+        "common-voice": {
+            "hifigan":{
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.0_models/vocoder_models--tr--common-voice--hifigan.zip",
+                "description": "HifiGAN model using an unknown speaker from the Common-Voice dataset.",
+                "author": "Fatih Akademi",
+                "license": "MIT",
+                "commit": null
+            }
+        }
+    }
 }
}
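
Entries in this model-zoo JSON are addressed by a `type/language/dataset/name` string. As a hedged sketch (not part of the commit) of how such an entry is resolved: `ModelManager.download_model` is the call used later in this same diff, but the `.models.json` location passed to the constructor and the exact return tuple are assumptions.

```python
# Hedged sketch: resolving a zoo entry by name. ModelManager.download_model
# appears elsewhere in this diff; the models-file path below is an assumption.
from pathlib import Path

from TTS.utils.manage import ModelManager

models_json = Path("TTS") / ".models.json"  # assumed location of the zoo file
manager = ModelManager(models_json)

# "tts_models/en/ljspeech/vits" maps to tts_models -> en -> ljspeech -> vits above
model_path, config_path, model_item = manager.download_model("tts_models/en/ljspeech/vits")
print(model_item["github_rls_url"])  # the release URL fields updated in this commit
```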
@@ -1 +1 @@
-0.5.0
+0.6.0
@@ -11,7 +11,7 @@ from tqdm import tqdm
 from TTS.config import load_config
 from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.models import setup_model
-from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
+from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_checkpoint

@@ -29,6 +29,9 @@ parser.add_argument(
     help="Path to dataset config file.",
 )
 parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
+parser.add_argument(
+    "--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None
+)
 parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
 parser.add_argument("--eval", type=bool, help="compute eval.", default=True)

@@ -40,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli
 wav_files = meta_data_train + meta_data_eval

 speaker_manager = SpeakerManager(
-    encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
+    encoder_model_path=args.model_path,
+    encoder_config_path=args.config_path,
+    d_vectors_file_path=args.old_file,
+    use_cuda=args.use_cuda,
 )

 # compute speaker embeddings
@@ -52,11 +58,15 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     else:
         speaker_name = None

-    # extract the embedding
-    embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
+    wav_file_name = os.path.basename(wav_file)
+    if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
+        # get the embedding from the old file
+        embedd = speaker_manager.get_d_vector_by_clip(wav_file_name)
+    else:
+        # extract the embedding
+        embedd = speaker_manager.compute_d_vector_from_clip(wav_file)

     # create speaker_mapping if target dataset is defined
-    wav_file_name = os.path.basename(wav_file)
     speaker_mapping[wav_file_name] = {}
     speaker_mapping[wav_file_name]["name"] = speaker_name
     speaker_mapping[wav_file_name]["embedding"] = embedd
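
The hunk above lets the embedding script resume from a previous `speakers.json` instead of recomputing every clip. A minimal sketch of that lookup-else-compute pattern in isolation; the `SpeakerManager` method names come from the diff, while all paths and file names are illustrative assumptions.

```python
# Hedged sketch of the resume pattern: reuse a d-vector from a previous
# speakers.json when the clip is already known, otherwise compute it fresh.
import os

from TTS.tts.utils.speakers import SpeakerManager

manager = SpeakerManager(
    encoder_model_path="encoder_model.pth.tar",  # hypothetical checkpoint path
    encoder_config_path="encoder_config.json",   # hypothetical config path
    d_vectors_file_path="speakers_old.json",     # previous run's output (--old_file)
    use_cuda=False,
)

wav_file = "dataset/wavs/clip_0001.wav"          # hypothetical clip
clip_name = os.path.basename(wav_file)
if clip_name in manager.clip_ids:
    embedding = manager.get_d_vector_by_clip(clip_name)       # cached lookup
else:
    embedding = manager.compute_d_vector_from_clip(wav_file)  # fresh compute
```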
@@ -51,7 +51,7 @@ def main():
     N = 0
     for item in tqdm(dataset_items):
         # compute features
-        wav = ap.load_wav(item if isinstance(item, str) else item[1])
+        wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
         linear = ap.spectrogram(wav)
         mel = ap.melspectrogram(wav)

@@ -59,13 +59,13 @@ def main():
         N += mel.shape[1]
         mel_sum += mel.sum(1)
         linear_sum += linear.sum(1)
-        mel_square_sum += (mel ** 2).sum(axis=1)
-        linear_square_sum += (linear ** 2).sum(axis=1)
+        mel_square_sum += (mel**2).sum(axis=1)
+        linear_square_sum += (linear**2).sum(axis=1)

     mel_mean = mel_sum / N
-    mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2)
+    mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2)
     linear_mean = linear_sum / N
-    linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2)
+    linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2)

     output_file_path = args.out_path
     stats = {}

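
The statistics loop above accumulates the frame count N plus per-bin sums and squared sums, so the normalization mean and scale come out of a single pass over the data. In symbols, matching the code:

```latex
\bar{x} = \frac{1}{N}\sum_{t=1}^{N} x_t,
\qquad
\sigma = \sqrt{\frac{1}{N}\sum_{t=1}^{N} x_t^{2} \;-\; \bar{x}^{2}}
```

This is the standard identity Var(x) = E[x^2] - E[x]^2 for the population standard deviation, applied independently to every mel and linear frequency bin, with x_t the bin value at frame t.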
@@ -1,25 +0,0 @@
-# Convert Tensorflow Tacotron2 model to TF-Lite binary
-
-import argparse
-
-from TTS.utils.io import load_config
-from TTS.vocoder.tf.utils.generic_utils import setup_generator
-from TTS.vocoder.tf.utils.io import load_checkpoint
-from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
-parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
-parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
-args = parser.parse_args()
-
-# Set constants
-CONFIG = load_config(args.config_path)
-
-# load the model
-model = setup_generator(CONFIG)
-model.build_inference()
-model = load_checkpoint(model, args.tf_model)
-
-# create tflite model
-tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path)
@@ -1,105 +0,0 @@
-import argparse
-import os
-from difflib import SequenceMatcher
-
-import numpy as np
-import tensorflow as tf
-import torch
-
-from TTS.utils.io import load_config, load_fsspec
-from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
-    compare_torch_tf,
-    convert_tf_name,
-    transfer_weights_torch_to_tf,
-)
-from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator
-from TTS.vocoder.tf.utils.io import save_checkpoint
-from TTS.vocoder.utils.generic_utils import setup_generator
-
-# prevent GPU use
-os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
-# define args
-parser = argparse.ArgumentParser()
-parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
-parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
-parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
-args = parser.parse_args()
-
-# load model config
-config_path = args.config_path
-c = load_config(config_path)
-num_speakers = 0
-
-# init torch model
-model = setup_generator(c)
-checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
-state_dict = checkpoint["model"]
-model.load_state_dict(state_dict)
-model.remove_weight_norm()
-state_dict = model.state_dict()
-
-# init tf model
-model_tf = setup_tf_generator(c)
-
-common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
-# get tf_model graph by passing an input
-# B x D x T
-dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
-mel_pred = model_tf(dummy_input, training=False)
-
-# get tf variables
-tf_vars = model_tf.weights
-
-# match variable names with fuzzy logic
-torch_var_names = list(state_dict.keys())
-tf_var_names = [we.name for we in model_tf.weights]
-var_map = []
-for tf_name in tf_var_names:
-    # skip re-mapped layer names
-    if tf_name in [name[0] for name in var_map]:
-        continue
-    tf_name_edited = convert_tf_name(tf_name)
-    ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
-    max_idx = np.argmax(ratios)
-    matching_name = torch_var_names[max_idx]
-    del torch_var_names[max_idx]
-    var_map.append((tf_name, matching_name))
-
-# pass weights
-tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
-
-# Compare TF and TORCH models
-# check embedding outputs
-model.eval()
-dummy_input_torch = torch.ones((1, 80, 10))
-dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
-dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1])
-dummy_input_tf = tf.expand_dims(dummy_input_tf, 2)
-
-out_torch = model.layers[0](dummy_input_torch)
-out_tf = model_tf.model_layers[0](dummy_input_tf)
-out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
-
-assert compare_torch_tf(out_torch, out_tf_) < 1e-5
-
-for i in range(1, len(model.layers)):
-    print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}")
-    out_torch = model.layers[i](out_torch)
-    out_tf = model_tf.model_layers[i](out_tf)
-    out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
-    diff = compare_torch_tf(out_torch, out_tf_)
-    assert diff < 1e-5, diff
-
-torch.manual_seed(0)
-dummy_input_torch = torch.rand((1, 80, 100))
-dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
-model.inference_padding = 0
-model_tf.inference_padding = 0
-output_torch = model.inference(dummy_input_torch)
-output_tf = model_tf(dummy_input_tf, training=False)
-assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf)
-
-# save tf model
-save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path)
-print(" > Model conversion is successfully completed :).")
@@ -1,30 +0,0 @@
-# Convert Tensorflow Tacotron2 model to TF-Lite binary
-
-import argparse
-
-from TTS.tts.tf.utils.generic_utils import setup_model
-from TTS.tts.tf.utils.io import load_checkpoint
-from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite
-from TTS.tts.utils.text.symbols import phonemes, symbols
-from TTS.utils.io import load_config
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
-parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
-parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
-args = parser.parse_args()
-
-# Set constants
-CONFIG = load_config(args.config_path)
-
-# load the model
-c = CONFIG
-num_speakers = 0
-num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
-model.build_inference()
-model = load_checkpoint(model, args.tf_model)
-model.decoder.set_max_decoder_steps(1000)
-
-# create tflite model
-tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)
@@ -1,187 +0,0 @@
-import argparse
-import os
-import sys
-from difflib import SequenceMatcher
-from pprint import pprint
-
-import numpy as np
-import tensorflow as tf
-import torch
-
-from TTS.tts.models import setup_model
-from TTS.tts.tf.models.tacotron2 import Tacotron2
-from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
-from TTS.tts.tf.utils.generic_utils import save_checkpoint
-from TTS.tts.utils.text.symbols import phonemes, symbols
-from TTS.utils.io import load_config, load_fsspec
-
-sys.path.append("/home/erogol/Projects")
-os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
-parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
-parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
-args = parser.parse_args()
-
-# load model config
-config_path = args.config_path
-c = load_config(config_path)
-num_speakers = 0
-
-# init torch model
-model = setup_model(c)
-checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
-state_dict = checkpoint["model"]
-model.load_state_dict(state_dict)
-
-# init tf model
-num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-model_tf = Tacotron2(
-    num_chars=num_chars,
-    num_speakers=num_speakers,
-    r=model.decoder.r,
-    out_channels=c.audio["num_mels"],
-    decoder_output_dim=c.audio["num_mels"],
-    attn_type=c.attention_type,
-    attn_win=c.windowing,
-    attn_norm=c.attention_norm,
-    prenet_type=c.prenet_type,
-    prenet_dropout=c.prenet_dropout,
-    forward_attn=c.use_forward_attn,
-    trans_agent=c.transition_agent,
-    forward_attn_mask=c.forward_attn_mask,
-    location_attn=c.location_attn,
-    attn_K=c.attention_heads,
-    separate_stopnet=c.separate_stopnet,
-    bidirectional_decoder=c.bidirectional_decoder,
-)
-
-# set initial layer mapping - these are not captured by the below heuristic approach
-# TODO: set layer names so that we can remove these manual matching
-common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
-var_map = [
-    ("embedding/embeddings:0", "embedding.weight"),
-    ("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"),
-    ("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"),
-    ("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"),
-    ("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"),
-    ("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")),
-    (
-        "encoder/lstm/backward_lstm/lstm_cell_2/bias:0",
-        ("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"),
-    ),
-    ("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"),
-    ("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"),
-    ("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"),
-]
-
-# %%
-# get tf_model graph
-model_tf.build_inference()
-
-# get tf variables
-tf_vars = model_tf.weights
-
-# match variable names with fuzzy logic
-torch_var_names = list(state_dict.keys())
-tf_var_names = [we.name for we in model_tf.weights]
-for tf_name in tf_var_names:
-    # skip re-mapped layer names
-    if tf_name in [name[0] for name in var_map]:
-        continue
-    tf_name_edited = convert_tf_name(tf_name)
-    ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
-    max_idx = np.argmax(ratios)
-    matching_name = torch_var_names[max_idx]
-    del torch_var_names[max_idx]
-    var_map.append((tf_name, matching_name))
-
-pprint(var_map)
-pprint(torch_var_names)
-
-# pass weights
-tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
-
-# Compare TF and TORCH models
-# %%
-# check embedding outputs
-model.eval()
-input_ids = torch.randint(0, 24, (1, 128)).long()
-
-o_t = model.embedding(input_ids)
-o_tf = model_tf.embedding(input_ids.detach().numpy())
-assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum()
-
-# compare encoder outputs
-oo_en = model.encoder.inference(o_t.transpose(1, 2))
-ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
-assert compare_torch_tf(oo_en, ooo_en) < 1e-5
-
-# pylint: disable=redefined-builtin
-# compare decoder.attention_rnn
-inp = torch.rand([1, 768])
-inp_tf = inp.numpy()
-model.decoder._init_states(oo_en, mask=None)  # pylint: disable=protected-access
-output, cell_state = model.decoder.attention_rnn(inp)
-states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
-output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False)
-assert compare_torch_tf(output, output_tf).mean() < 1e-5
-
-query = output
-inputs = torch.rand([1, 128, 512])
-query_tf = query.detach().numpy()
-inputs_tf = inputs.numpy()
-
-# compare decoder.attention
-model.decoder.attention.init_states(inputs)
-processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
-loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs)
-context = model.decoder.attention(query, inputs, processes_inputs, None)
-
-attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
-model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
-loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states)
-context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False)
-
-assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
-assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
-assert compare_torch_tf(context, context_tf) < 1e-5
-
-# compare decoder.decoder_rnn
-input = torch.rand([1, 1536])
-input_tf = input.numpy()
-model.decoder._init_states(oo_en, mask=None)  # pylint: disable=protected-access
-output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
-states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
-output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False)
-assert abs(input - input_tf).mean() < 1e-5
-assert compare_torch_tf(output, output_tf).mean() < 1e-5
-
-# compare decoder.linear_projection
-input = torch.rand([1, 1536])
-input_tf = input.numpy()
-output = model.decoder.linear_projection(input)
-output_tf = model_tf.decoder.linear_projection(input_tf, training=False)
-assert compare_torch_tf(output, output_tf) < 1e-5
-
-# compare decoder outputs
-model.decoder.max_decoder_steps = 100
-model_tf.decoder.set_max_decoder_steps(100)
-output, align, stop = model.decoder.inference(oo_en)
-states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
-output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
-assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4
-
-# compare the whole model output
-outputs_torch = model.inference(input_ids)
-outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
-print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean())
-assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5
-assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
-
-# %%
-# save tf model
-save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path)
-print(" > Model conversion is successfully completed :).")
@@ -7,15 +7,14 @@ import subprocess
 import time

 import torch

-from TTS.trainer import TrainingArgs
+from trainer import TrainerArgs

-
 def main():
     """
     Call train.py as a new process and pass command arguments
     """
-    parser = TrainingArgs().init_argparse(arg_prefix="")
+    parser = TrainerArgs().init_argparse(arg_prefix="")
     parser.add_argument("--script", type=str, help="Target training script to distibute.")
     args, unargs = parser.parse_known_args()

@@ -13,6 +13,7 @@ from TTS.config import load_config
 from TTS.tts.datasets import TTSDataset, load_tts_samples
 from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters

@@ -20,21 +21,20 @@ use_cuda = torch.cuda.is_available()


 def setup_loader(ap, r, verbose=False):
+    tokenizer, _ = TTSTokenizer.init_from_config(c)
     dataset = TTSDataset(
-        r,
-        c.text_cleaner,
+        outputs_per_step=r,
         compute_linear_spec=False,
-        meta_data=meta_data,
+        samples=meta_data,
+        tokenizer=tokenizer,
         ap=ap,
-        characters=c.characters if "characters" in c.keys() else None,
-        add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
         batch_group_size=0,
-        min_seq_len=c.min_seq_len,
-        max_seq_len=c.max_seq_len,
+        min_text_len=c.min_text_len,
+        max_text_len=c.max_text_len,
+        min_audio_len=c.min_audio_len,
+        max_audio_len=c.max_audio_len,
         phoneme_cache_path=c.phoneme_cache_path,
         use_phonemes=c.use_phonemes,
         phoneme_language=c.phoneme_language,
-        enable_eos_bos=c.enable_eos_bos_chars,
+        precompute_num_workers=0,
         use_noise_augment=False,
         verbose=verbose,
         speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
@@ -44,7 +44,7 @@ def setup_loader(ap, r, verbose=False):
     if c.use_phonemes and c.compute_input_seq_cache:
         # precompute phonemes to have a better estimate of sequence lengths.
         dataset.compute_input_seq(c.num_loader_workers)
-    dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
+    dataset.preprocess_samples()

     loader = DataLoader(
         dataset,
@@ -75,8 +75,8 @@ def set_filename(wav_path, out_path):

 def format_data(data):
     # setup input data
-    text_input = data["text"]
-    text_lengths = data["text_lengths"]
+    text_input = data["token_id"]
+    text_lengths = data["token_id_lengths"]
     mel_input = data["mel"]
     mel_lengths = data["mel_lengths"]
     item_idx = data["item_idxs"]
@@ -138,7 +138,7 @@ def inference(
             aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
         )
         model_output = outputs["model_outputs"]
-        model_output = model_output.transpose(1, 2).detach().cpu().numpy()
+        model_output = model_output.detach().cpu().numpy()

     elif "tacotron" in model_name:
         aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
@@ -229,7 +229,9 @@ def main(args):  # pylint: disable=redefined-outer-name
     ap = AudioProcessor(**c.audio)

     # load data instances
-    meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval)
+    meta_data_train, meta_data_eval = load_tts_samples(
+        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
+    )

     # use eval and training partitions
     meta_data = meta_data_train + meta_data_eval
@@ -23,7 +23,10 @@ def main():
     c = load_config(args.config_path)

     # load all datasets
-    train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
+    train_items, eval_items = load_tts_samples(
+        c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
+    )

     items = train_items + eval_items

     texts = "".join(item[0] for item in items)
@@ -7,14 +7,15 @@ from tqdm.contrib.concurrent import process_map

 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.utils.text import text2phone
+from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
+
+phonemizer = Gruut(language="en-us")


 def compute_phonemes(item):
     try:
         text = item[0]
-        language = item[-1]
-        ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
+        ph = phonemizer.phonemize(text).split("|")
     except:
         return []
     return list(set(ph))
@@ -39,10 +40,17 @@ def main():
     c = load_config(args.config_path)

     # load all datasets
-    train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
+    train_items, eval_items = load_tts_samples(
+        c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
+    )
     items = train_items + eval_items
     print("Num items:", len(items))

+    is_lang_def = all(item["language"] for item in items)
+
+    if not c.phoneme_language or not is_lang_def:
+        raise ValueError("Phoneme language must be defined in config.")
+
     phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
     phones = []
     for ph in phonemes:
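
The hunk above swaps the old `text2phone` helper for a module-level Gruut phonemizer. A minimal usage sketch of that new call path; the class, import path, and the "|"-separated output convention are taken from the diff, while the sample text is illustrative.

```python
# Hedged sketch of the Gruut phonemizer introduced above.
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut

phonemizer = Gruut(language="en-us")
# phonemize() returns a "|"-separated phoneme string, as the diff's .split("|") implies
phonemes = phonemizer.phonemize("Hello world").split("|")
print(sorted(set(phonemes)))  # the unique-phoneme set the script accumulates
```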
@@ -26,6 +26,7 @@ if __name__ == "__main__":
             --input_dir /root/LJSpeech-1.1/
             --output_sr 22050
             --output_dir /root/resampled_LJSpeech-1.1/
+            --file_ext wav
             --n_jobs 24
         """,
         formatter_class=RawTextHelpFormatter,
@@ -55,6 +56,14 @@ if __name__ == "__main__":
         help="Path of the destination folder. If not defined, the operation is done in place",
     )

+    parser.add_argument(
+        "--file_ext",
+        type=str,
+        default="wav",
+        required=False,
+        help="Extension of the audio files to resample",
+    )
+
     parser.add_argument(
         "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores"
     )
@@ -67,7 +76,7 @@ if __name__ == "__main__":
         args.input_dir = args.output_dir

     print("Resampling the audio files...")
-    audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True)
+    audio_files = glob.glob(os.path.join(args.input_dir, f"**/*.{args.file_ext}"), recursive=True)
     print(f"Found {len(audio_files)} files...")
     audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr]))
     with Pool(processes=args.n_jobs) as p:
@@ -8,6 +8,7 @@ import traceback

 import torch
 from torch.utils.data import DataLoader
+from trainer.torch import NoamLR

 from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
@@ -19,7 +20,7 @@ from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
 from TTS.utils.io import load_fsspec
 from TTS.utils.radam import RAdam
-from TTS.utils.training import NoamLR, check_update
+from TTS.utils.training import check_update

 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
@@ -1,19 +1,22 @@
 import os
-import torch
+from dataclasses import dataclass, field

-from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
-from TTS.trainer import Trainer, TrainingArgs
+from trainer import Trainer, TrainerArgs

+from TTS.config import load_config, register_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models import setup_model
-from TTS.tts.utils.languages import LanguageManager
-from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.audio import AudioProcessor


+@dataclass
+class TrainTTSArgs(TrainerArgs):
+    config_path: str = field(default=None, metadata={"help": "Path to the config file."})
+
+
 def main():
     """Run `tts` model training directly by a `config.json` file."""
     # init trainer args
-    train_args = TrainingArgs()
+    train_args = TrainTTSArgs()
     parser = train_args.init_argparse(arg_prefix="")

     # override trainer args from comman-line args
@@ -41,45 +44,15 @@ def main():
     config = register_config(config_base.model)()

     # load training samples
-    train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True)
-
-    # setup audio processor
-    ap = AudioProcessor(**config.audio)
-
-    # init speaker manager
-    if check_config_and_model_args(config, "use_speaker_embedding", True):
-        speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
-        if hasattr(config, "model_args"):
-            config.model_args.num_speakers = speaker_manager.num_speakers
-        else:
-            config.num_speakers = speaker_manager.num_speakers
-    elif check_config_and_model_args(config, "use_d_vector_file", True):
-        if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
-            speaker_manager = SpeakerManager(
-                d_vectors_file_path=config.model_args.d_vector_file,
-                encoder_model_path=config.model_args.speaker_encoder_model_path,
-                encoder_config_path=config.model_args.speaker_encoder_config_path,
-                use_cuda=torch.cuda.is_available(),
-            )
-        else:
-            speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
-        config.num_speakers = speaker_manager.num_speakers
-        if hasattr(config, "model_args"):
-            config.model_args.num_speakers = speaker_manager.num_speakers
-    else:
-        speaker_manager = None
-
-    if check_config_and_model_args(config, "use_language_embedding", True):
-        language_manager = LanguageManager(config=config)
-        if hasattr(config, "model_args"):
-            config.model_args.num_languages = language_manager.num_languages
-        else:
-            config.num_languages = language_manager.num_languages
-    else:
-        language_manager = None
+    train_samples, eval_samples = load_tts_samples(
+        config.datasets,
+        eval_split=True,
+        eval_split_max_size=config.eval_split_max_size,
+        eval_split_size=config.eval_split_size,
+    )

     # init the model from config
-    model = setup_model(config, speaker_manager, language_manager)
+    model = setup_model(config, train_samples + eval_samples)

     # init the trainer and 🚀
     trainer = Trainer(
@@ -89,7 +62,6 @@ def main():
         model=model,
         train_samples=train_samples,
         eval_samples=eval_samples,
-        training_assets={"audio_processor": ap},
         parse_command_line_args=False,
     )
     trainer.fit()

@@ -1,16 +1,23 @@
 import os
+from dataclasses import dataclass, field
+
+from trainer import Trainer, TrainerArgs

 from TTS.config import load_config, register_config
-from TTS.trainer import Trainer, TrainingArgs
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model


+@dataclass
+class TrainVocoderArgs(TrainerArgs):
+    config_path: str = field(default=None, metadata={"help": "Path to the config file."})
+
+
 def main():
     """Run `tts` model training directly by a `config.json` file."""
     # init trainer args
-    train_args = TrainingArgs()
+    train_args = TrainVocoderArgs()
     parser = train_args.init_argparse(arg_prefix="")

     # override trainer args from comman-line args
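
Both train_tts.py (above) and train_vocoder.py now follow the same pattern: command-line arguments come from the external `trainer` package, extended per script with one dataclass field. A hedged sketch of the pattern in isolation; the class name below is a hypothetical stand-in for TrainTTSArgs/TrainVocoderArgs, and everything else mirrors the diff.

```python
# Hedged sketch of the TrainerArgs-extension pattern used by the training
# entry points in this commit. TrainScriptArgs is an illustrative name.
from dataclasses import dataclass, field

from trainer import TrainerArgs


@dataclass
class TrainScriptArgs(TrainerArgs):  # hypothetical variant of TrainTTSArgs
    config_path: str = field(default=None, metadata={"help": "Path to the config file."})


train_args = TrainScriptArgs()
parser = train_args.init_argparse(arg_prefix="")  # "" exposes flags without a prefix
args, unknown = parser.parse_known_args()
```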
@@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass
 from typing import List

 from coqpit import Coqpit, check_argument
+from trainer import TrainerConfig


 @dataclass
@@ -57,6 +58,12 @@ class BaseAudioConfig(Coqpit):
         do_amp_to_db_mel (bool, optional):
             enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.

+        pitch_fmax (float, optional):
+            Maximum frequency of the F0 frames. Defaults to ```640```.
+
+        pitch_fmin (float, optional):
+            Minimum frequency of the F0 frames. Defaults to ```0```.
+
         trim_db (int):
             Silence threshold used for silence trimming. Defaults to 45.

@@ -135,6 +142,9 @@ class BaseAudioConfig(Coqpit):
     spec_gain: int = 20
     do_amp_to_db_linear: bool = True
     do_amp_to_db_mel: bool = True
+    # f0 params
+    pitch_fmax: float = 640.0
+    pitch_fmin: float = 0.0
     # normalization params
     signal_norm: bool = True
     min_level_db: int = -100
@@ -228,130 +238,24 @@ class BaseDatasetConfig(Coqpit):


 @dataclass
-class BaseTrainingConfig(Coqpit):
-    """Base config to define the basic training parameters that are shared
-    among all the models.
+class BaseTrainingConfig(TrainerConfig):
+    """Base config to define the basic 🐸TTS training parameters that are shared
+    among all the models. It is based on ```Trainer.TrainingConfig```.

     Args:
         model (str):
             Name of the model that is used in the training.

-        run_name (str):
-            Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.
-
-        run_description (str):
-            Short description of the experiment.
-
-        epochs (int):
-            Number training epochs. Defaults to 10000.
-
-        batch_size (int):
-            Training batch size.
-
-        eval_batch_size (int):
-            Validation batch size.
-
-        mixed_precision (bool):
-            Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
-            it may also cause numerical unstability in some cases.
-
-        scheduler_after_epoch (bool):
-            If true, run the scheduler step after each epoch else run it after each model step.
-
-        run_eval (bool):
-            Enable / Disable evaluation (validation) run. Defaults to True.
-
-        test_delay_epochs (int):
-            Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
-            results, hence waiting for a couple of epochs might save some time.
-
-        print_eval (bool):
-            Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at
-            the end of the evaluation. Default to ```False```.
-
-        print_step (int):
-            Number of steps required to print the next training log.
-
-        log_dashboard (str): "tensorboard" or "wandb"
-            Set the experiment tracking tool
-
-        plot_step (int):
-            Number of steps required to log training on Tensorboard.
-
-        model_param_stats (bool):
-            Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
-            Defaults to ```False```.
-
-        project_name (str):
-            Name of the project. Defaults to config.model
-
-        wandb_entity (str):
-            Name of W&B entity/team. Enables collaboration across a team or org.
-
-        log_model_step (int):
-            Number of steps required to log a checkpoint as W&B artifact
-
-        save_step (int):
-            Number of steps required to save the next checkpoint.
-
-        checkpoint (bool):
-            Enable / Disable checkpointing.
-
-        keep_all_best (bool):
-            Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
-            to ```False```.
-
-        keep_after (int):
-            Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
-            to 10000.
-
         num_loader_workers (int):
             Number of workers for training time dataloader.

         num_eval_loader_workers (int):
             Number of workers for evaluation time dataloader.

-        output_path (str):
-            Path for training output folder, either a local file path or other
-            URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
-            S3 (s3://) paths. The nonexist part of the given path is created
-            automatically. All training artefacts are saved there.
     """

     model: str = None
-    run_name: str = "coqui_tts"
-    run_description: str = ""
-    # training params
-    epochs: int = 10000
-    batch_size: int = None
-    eval_batch_size: int = None
-    mixed_precision: bool = False
-    scheduler_after_epoch: bool = False
-    # eval params
-    run_eval: bool = True
-    test_delay_epochs: int = 0
-    print_eval: bool = False
-    # logging
-    dashboard_logger: str = "tensorboard"
-    print_step: int = 25
-    plot_step: int = 100
-    model_param_stats: bool = False
-    project_name: str = None
-    log_model_step: int = None
-    wandb_entity: str = None
-    # checkpointing
-    save_step: int = 10000
-    checkpoint: bool = True
-    keep_all_best: bool = False
-    keep_after: int = 10000
     # dataloading
     num_loader_workers: int = 0
     num_eval_loader_workers: int = 0
     use_noise_augment: bool = False
-    use_language_weighted_sampler: bool = False
-
-    # paths
-    output_path: str = None
-    # distributed
-    distributed_backend: str = "nccl"
-    distributed_url: str = "tcp://localhost:54321"
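
With `BaseTrainingConfig` now subclassing `TrainerConfig`, the generic training fields removed above are expected to come from the external `trainer` package instead of being redefined in 🐸TTS. A hedged sketch of what that means for downstream use; the field values are illustrative and the inherited `epochs` attribute is an assumption about `TrainerConfig`, not something this diff shows.

```python
# Hedged sketch: TTS keeps only its own fields; generic ones are inherited.
from TTS.config.shared_configs import BaseTrainingConfig

config = BaseTrainingConfig(
    model="tacotron2",       # TTS-side field kept in this diff (illustrative value)
    num_loader_workers=4,    # TTS-side field kept in this diff
)
print(config.epochs)         # assumed to be inherited from trainer.TrainerConfig
```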
101
TTS/model.py
101
TTS/model.py
|
@ -1,7 +1,6 @@
|
|||
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

 import numpy as np
 import torch
 from coqpit import Coqpit
 from torch import nn

@@ -9,28 +8,21 @@ from torch import nn
 # pylint: skip-file


-class BaseModel(nn.Module, ABC):
-    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
+class BaseTrainerModel(ABC, nn.Module):
+    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""

-    Notes on input/output tensor shapes:
-        Any input or output tensor of the model must be shaped as
+    @staticmethod
+    @abstractmethod
+    def init_from_config(config: Coqpit):
+        """Init the model from given config.

-        - 3D tensors `batch x time x channels`
-        - 2D tensors `batch x channels`
-        - 1D tensors `batch x 1`
-    """
-
-    def __init__(self, config: Coqpit):
-        super().__init__()
-        self._set_model_args(config)
-
-    def _set_model_args(self, config: Coqpit):
-        """Set model arguments from the config. Override this."""
-        pass
+        Override this depending on your model.
+        """
+        ...

     @abstractmethod
     def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
-        """Forward pass for the model mainly used in training.
+        """Forward ... for the model mainly used in training.

         You can be flexible here and use different number of arguments and argument names since it is intended to be
         used by `train_step()` without exposing it out of the model.

@@ -48,7 +40,7 @@ class BaseModel(nn.Module, ABC):

     @abstractmethod
     def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
-        """Forward pass for inference.
+        """Forward ... for inference.

         We don't use `*kwargs` since it is problematic with the TorchScript API.

@@ -63,9 +55,25 @@ class BaseModel(nn.Module, ABC):
         ...
         return outputs_dict

+    def format_batch(self, batch: Dict) -> Dict:
+        """Format batch returned by the data loader before sending it to the model.
+
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature extraction, etc.
+        """
+        return batch
+
+    def format_batch_on_device(self, batch: Dict) -> Dict:
+        """Format batch on device before sending it to the model.
+
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature extraction, etc.
+        """
+        return batch
+
     @abstractmethod
     def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single training step. Run the model forward pass and compute losses.
+        """Perform a single training step. Run the model forward ... and compute losses.

         Args:
             batch (Dict): Input tensors.

@@ -93,11 +101,11 @@ class BaseModel(nn.Module, ABC):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        pass
+        ...

     @abstractmethod
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
+        """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can
         call `train_step()` with no changes.

         Args:

@@ -114,36 +122,49 @@ class BaseModel(nn.Module, ABC):

     def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
         """The same as `train_log()`"""
-        pass
+        ...

     @abstractmethod
-    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
+    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
         """Load a checkpoint and get ready for training or inference.

         Args:
             config (Coqpit): Model configuration.
             checkpoint_path (str): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
+            strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
         """
         ...

-    def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
-        """Setup an return optimizer or optimizers."""
-        pass
+    @staticmethod
+    @abstractmethod
+    def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel":
+        """Init the model from given config.

-    def get_lr(self) -> Union[float, List[float]]:
-        """Return learning rate(s).
-
-        Returns:
-            Union[float, List[float]]: Model's initial learning rates.
+        Override this depending on your model.
         """
-        pass
+        ...

-    def get_scheduler(self, optimizer: torch.optim.Optimizer):
-        pass
+    @abstractmethod
+    def get_data_loader(
+        self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
+    ):
+        ...

-    def get_criterion(self):
-        pass
+    # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
+    #     """Setup an return optimizer or optimizers."""
+    #     ...

-    def format_batch(self):
-        pass
+    # def get_lr(self) -> Union[float, List[float]]:
+    #     """Return learning rate(s).
+
+    #     Returns:
+    #         Union[float, List[float]]: Model's initial learning rates.
+    #     """
+    #     ...
+
+    # def get_scheduler(self, optimizer: torch.optim.Optimizer):
+    #     ...
+
+    # def get_criterion(self):
+    #     ...

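For orientation, here is a minimal sketch of a model written against the `BaseTrainerModel` interface in the diff above. `MyModel`, its single layer, and the checkpoint key `"model"` are illustrative assumptions, not code from this repository.

from typing import Dict, List, Tuple

import torch
from coqpit import Coqpit
from torch import nn

from TTS.model import BaseTrainerModel  # module path assumed from this diff


class MyModel(BaseTrainerModel):
    """Hypothetical minimal implementation of the abstract methods above."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(80, 80)  # placeholder layer

    @staticmethod
    def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "MyModel":
        return MyModel()

    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
        return {"model_outputs": self.layer(input)}

    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
        return self.forward(input)

    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        outputs = self.forward(batch["input"])
        loss = criterion(outputs["model_outputs"], batch["target"])
        return outputs, {"loss": loss}

    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion)

    def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: bool, data_items: List, verbose: bool, num_gpus: int):
        ...

    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
        state = torch.load(checkpoint_path, map_location="cpu")
        self.load_state_dict(state["model"], strict=strict)  # "model" key is an assumption
        if eval:
            self.eval()
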
@@ -88,7 +88,7 @@ if args.model_name is not None and not args.model_path:
 if args.vocoder_name is not None and not args.vocoder_path:
     vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

-# CASE3: set custome model paths
+# CASE3: set custom model paths
 if args.model_path is not None:
     model_path = args.model_path
     config_path = args.config_path

@@ -170,9 +170,9 @@ def tts():
     text = request.args.get("text")
     speaker_idx = request.args.get("speaker_id", "")
     style_wav = request.args.get("style_wav", "")

     style_wav = style_wav_uri_to_dict(style_wav)
     print(" > Model input: {}".format(text))
     print(" > Speaker Idx: {}".format(speaker_idx))
     wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
     out = io.BytesIO()
     synthesizer.save_wav(wavs, out)

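The `tts()` route above takes `text`, `speaker_id`, and `style_wav` as query parameters and writes the synthesized WAV into the response body. A hedged client sketch follows; the `/api/tts` path and port 5002 are assumptions based on the demo server's usual defaults, so verify them against your deployment.

import requests

resp = requests.get(
    "http://localhost:5002/api/tts",  # endpoint path and port assumed
    params={"text": "Hello world.", "speaker_id": "", "style_wav": ""},
)
resp.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(resp.content)  # raw WAV bytes produced by synthesizer.save_wav()
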
@@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset):
         mel = self.ap.melspectrogram(wav).astype("float32")
         # sample seq_len

-        assert text.size > 0, self.items[idx][1]
-        assert wav.size > 0, self.items[idx][1]
+        assert text.size > 0, self.items[idx]["audio_file"]
+        assert wav.size > 0, self.items[idx]["audio_file"]

         sample = {
             "mel": mel,
-            "item_idx": self.items[idx][1],
+            "item_idx": self.items[idx]["audio_file"],
             "speaker_name": speaker_name,
         }
         return sample

@@ -91,8 +91,8 @@ class SpeakerEncoderDataset(Dataset):
     def __parse_items(self):
         self.speaker_to_utters = {}
         for i in self.items:
-            path_ = i[1]
-            speaker_ = i[2]
+            path_ = i["audio_file"]
+            speaker_ = i["speaker_name"]
             if speaker_ in self.speaker_to_utters.keys():
                 self.speaker_to_utters[speaker_].append(path_)
             else:

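These changes replace positional indexing (`i[1]`, `i[2]`) with named keys, which implies the dataset formatters now return dict samples. Based only on the keys visible in this diff, a sample looks roughly like the sketch below; the values are made up.

sample = {
    "text": "Hello world.",                                  # transcript
    "audio_file": "/data/LJSpeech-1.1/wavs/LJ001-0001.wav",  # hypothetical path
    "speaker_name": "ljspeech",
}
path_, speaker_ = sample["audio_file"], sample["speaker_name"]
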
@@ -229,7 +229,7 @@ class ResNetSpeakerEncoder(nn.Module):
             x = torch.sum(x * w, dim=2)
         elif self.encoder_type == "ASP":
             mu = torch.sum(x * w, dim=2)
-            sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
+            sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
             x = torch.cat((mu, sg), 1)

         x = x.view(x.size()[0], -1)

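The `ASP` branch above is attentive statistics pooling: an attention-weighted mean and an attention-weighted standard deviation over the time axis, concatenated. The `clamp(min=1e-5)` guards the square root against tiny negative variances from floating-point error; the diff line itself only changes `**` spacing. A standalone sketch with made-up shapes:

import torch

x = torch.randn(4, 256, 100)                         # (batch, channels, time), illustrative
w = torch.softmax(torch.randn(4, 256, 100), dim=2)   # attention weights over time

mu = torch.sum(x * w, dim=2)                         # attention-weighted mean
var = torch.sum((x**2) * w, dim=2) - mu**2           # weighted variance, E[x^2] - E[x]^2
sg = torch.sqrt(var.clamp(min=1e-5))                 # clamp keeps sqrt numerically safe
pooled = torch.cat((mu, sg), 1)                      # (batch, 2 * channels)
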
@@ -113,7 +113,7 @@ class AugmentWAV(object):

     def additive_noise(self, noise_type, audio):

-        clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
+        clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)

         noise_list = random.sample(
             self.noise_list[noise_type],

@@ -135,7 +135,7 @@ class AugmentWAV(object):
             self.additive_noise_config[noise_type]["min_snr_in_db"],
             self.additive_noise_config[noise_type]["max_num_noises"],
         )
-        noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
+        noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
         noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio

         if noises_wav is None:

@@ -154,7 +154,7 @@ class AugmentWAV(object):

         rir_file = random.choice(self.rir_files)
         rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
-        rir = rir / np.sqrt(np.sum(rir ** 2))
+        rir = rir / np.sqrt(np.sum(rir**2))
         return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]

     def apply_one(self, audio):

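`additive_noise` mixes noise at a target SNR: with signal power in dB computed as `10 * log10(mean(x**2))`, scaling the noise by `sqrt(10 ** ((clean_db - noise_db - snr) / 10))` makes the clean-to-noise power ratio land on `snr` dB. A worked sketch with stand-in signals:

import numpy as np

audio = np.random.randn(16000).astype(np.float32)  # stand-in clean signal
noise = np.random.randn(16000).astype(np.float32)  # stand-in noise clip
target_snr = 10.0                                  # desired SNR in dB

clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
noise_db = 10 * np.log10(np.mean(noise**2) + 1e-4)
scale = np.sqrt(10 ** ((clean_db - noise_db - target_snr) / 10))
mixed = audio + scale * noise

# realized SNR comes out close to the 10 dB target (up to the 1e-4 stabilizer)
snr_db = 10 * np.log10(np.mean(audio**2) / np.mean((scale * noise) ** 2))
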
@@ -1,19 +1,24 @@
 import os
 from dataclasses import dataclass, field

 from coqpit import Coqpit
+from trainer import TrainerArgs, get_last_checkpoint
+from trainer.logging import logger_factory
+from trainer.logging.console_logger import ConsoleLogger

 from TTS.config import load_config, register_config
-from TTS.trainer import TrainingArgs
-from TTS.tts.utils.text.symbols import parse_symbols
+from TTS.tts.utils.text.characters import parse_symbols
 from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
 from TTS.utils.io import copy_model_files
-from TTS.utils.logging import init_dashboard_logger
-from TTS.utils.logging.console_logger import ConsoleLogger
-from TTS.utils.trainer_utils import get_last_checkpoint


+@dataclass
+class TrainArgs(TrainerArgs):
+    config_path: str = field(default=None, metadata={"help": "Path to the config file."})
+
+
 def getarguments():
-    train_config = TrainingArgs()
+    train_config = TrainArgs()
     parser = train_config.init_argparse(arg_prefix="")
     return parser

@@ -75,13 +80,13 @@ def process_args(args, config=None):
         used_characters = parse_symbols()
         new_fields["characters"] = used_characters
     copy_model_files(config, experiment_path, new_fields)
-    dashboard_logger = init_dashboard_logger(config)
+    dashboard_logger = logger_factory(config, experiment_path)
     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, dashboard_logger


 def init_arguments():
-    train_config = TrainingArgs()
+    train_config = TrainArgs()
     parser = train_config.init_argparse(arg_prefix="")
     return parser

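`TrainArgs` extends the trainer package's `TrainerArgs` with a `--config_path` flag, and Coqpit generates the argparse parser from the dataclass fields. A small usage sketch mirroring the helpers above; the flag name follows from the field name, and the parsing behaviour is assumed standard Coqpit.

parser = TrainArgs().init_argparse(arg_prefix="")
args, unknown = parser.parse_known_args(["--config_path", "config.json"])
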
TTS/trainer.py (1199 changed lines): file diff suppressed because it is too large.

@@ -89,11 +89,11 @@ class FastPitchConfig(BaseTTSConfig):
         pitch_loss_alpha (float):
             Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.

-        binary_loss_alpha (float):
+        binary_align_loss_alpha (float):
             Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.

-        binary_align_loss_start_step (int):
-            Start binary alignment loss after this many steps. Defaults to 20000.
+        binary_loss_warmup_epochs (float):
+            Number of epochs to gradually increase the binary loss impact. Defaults to 150.

         min_seq_len (int):
             Minimum input sequence length to be used at training.

@@ -129,12 +129,12 @@ class FastPitchConfig(BaseTTSConfig):
     duration_loss_type: str = "mse"
     use_ssim_loss: bool = True
     ssim_loss_alpha: float = 1.0
-    dur_loss_alpha: float = 1.0
     spec_loss_alpha: float = 1.0
-    pitch_loss_alpha: float = 1.0
     aligner_loss_alpha: float = 1.0
-    binary_align_loss_alpha: float = 1.0
-    binary_align_loss_start_step: int = 20000
+    pitch_loss_alpha: float = 0.1
+    dur_loss_alpha: float = 0.1
+    binary_align_loss_alpha: float = 0.1
+    binary_loss_warmup_epochs: int = 150

     # overrides
     min_seq_len: int = 13

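The config swaps a hard start step (`binary_align_loss_start_step`) for a gradual ramp (`binary_loss_warmup_epochs`). The actual schedule lives in the model and loss code, which this page does not show; a linear ramp consistent with the new field names could look like:

def binary_loss_weight(epoch: int, warmup_epochs: int = 150, alpha: float = 0.1) -> float:
    """Hypothetical linear warmup of the binary alignment loss weight."""
    return alpha * min(1.0, epoch / warmup_epochs)

# halfway through a 150-epoch warmup the weight is half its final value
assert abs(binary_loss_weight(75) - 0.05) < 1e-9
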
@@ -93,8 +93,8 @@ class FastSpeechConfig(BaseTTSConfig):
         binary_loss_alpha (float):
             Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.

-        binary_align_loss_start_step (int):
-            Start binary alignment loss after this many steps. Defaults to 20000.
+        binary_loss_warmup_epochs (float):
+            Number of epochs to gradually increase the binary loss impact. Defaults to 150.

         min_seq_len (int):
             Minimum input sequence length to be used at training.

@@ -135,7 +135,7 @@ class FastSpeechConfig(BaseTTSConfig):
     pitch_loss_alpha: float = 0.0
     aligner_loss_alpha: float = 1.0
     binary_align_loss_alpha: float = 1.0
-    binary_align_loss_start_step: int = 20000
+    binary_loss_warmup_epochs: int = 150

     # overrides
     min_seq_len: int = 13

@@ -153,6 +153,7 @@ class GlowTTSConfig(BaseTTSConfig):

     # multi-speaker settings
     use_speaker_embedding: bool = False
+    speakers_file: str = None
     use_d_vector_file: bool = False
     d_vector_file: str = False

@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import List
+from typing import Dict, List

 from coqpit import Coqpit, check_argument

@@ -50,9 +50,16 @@ class GSTConfig(Coqpit):

 @dataclass
 class CharactersConfig(Coqpit):
-    """Defines character or phoneme set used by the model
+    """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses.

     Args:
+        characters_class (str):
+            Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on
+            the configuration. Defaults to None.
+
+        vocab_dict (dict):
+            Defines the vocabulary dictionary used to encode the characters. Defaults to None.
+
         pad (str):
             characters in place of empty padding. Defaults to None.

@@ -62,6 +69,9 @@ class CharactersConfig(Coqpit):
         bos (str):
             characters showing the beginning of a sentence. Defaults to None.

+        blank (str):
+            Optional character used between characters by some models for better prosody. Defaults to `_blank`.
+
         characters (str):
             character set used by the model. Characters not in this list are ignored when converting input text to
             a list of sequence IDs. Defaults to None.

@@ -70,32 +80,32 @@ class CharactersConfig(Coqpit):
             characters considered as punctuation as parsing the input sentence. Defaults to None.

         phonemes (str):
-            characters considered as parsing phonemes. Defaults to None.
+            characters considered as parsing phonemes. This is only for backwards compat. Use `characters` for new
+            models. Defaults to None.

-        unique (bool):
+        is_unique (bool):
             remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old
-            models trained with character lists with duplicates.
+            models trained with character lists with duplicates. Defaults to True.
+
+        is_sorted (bool):
+            Sort the characters in alphabetical order. Defaults to True.
     """

+    characters_class: str = None
+
+    # using BaseVocabulary
+    vocab_dict: Dict = None
+
+    # using on BaseCharacters
     pad: str = None
     eos: str = None
     bos: str = None
+    blank: str = None
     characters: str = None
     punctuations: str = None
     phonemes: str = None
-    unique: bool = True  # for backwards compatibility of models trained with char sets with duplicates
-
-    def check_values(
-        self,
-    ):
-        """Check config fields"""
-        c = asdict(self)
-        check_argument("pad", c, prerequest="characters", restricted=True)
-        check_argument("eos", c, prerequest="characters", restricted=True)
-        check_argument("bos", c, prerequest="characters", restricted=True)
-        check_argument("characters", c, prerequest="characters", restricted=True)
-        check_argument("phonemes", c, restricted=True)
-        check_argument("punctuations", c, prerequest="characters", restricted=True)
+    is_unique: bool = True  # for backwards compatibility of models trained with char sets with duplicates
+    is_sorted: bool = True


 @dataclass

@@ -110,8 +120,13 @@ class BaseTTSConfig(BaseTrainingConfig):
         use_phonemes (bool):
             enable / disable phoneme use.

-        use_espeak_phonemes (bool):
-            enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`).
+        phonemizer (str):
+            Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`.
+            Defaults to None.
+
+        phoneme_language (str):
+            Language code for the phonemizer. You can check the list of supported languages by running
+            `python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None.

         compute_input_seq_cache (bool):
             enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of

@@ -144,11 +159,19 @@ class BaseTTSConfig(BaseTrainingConfig):
         sort_by_audio_len (bool):
             If true, dataloader sorts the data by audio length else sorts by the input text length. Defaults to `False`.

-        min_seq_len (int):
-            Minimum sequence length to be used at training.
+        min_text_len (int):
+            Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.

-        max_seq_len (int):
-            Maximum sequence length to be used at training. Larger values result in more VRAM usage.
+        max_text_len (int):
+            Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").
+
+        min_audio_len (int):
+            Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.
+
+        max_audio_len (int):
+            Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
+            dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
+            OOM error in training. Defaults to float("inf").

         compute_f0 (int):
             (Not in use yet).

@@ -156,9 +179,16 @@ class BaseTTSConfig(BaseTrainingConfig):
         compute_linear_spec (bool):
             If True data loader computes and returns linear spectrograms alongside the other data.

+        precompute_num_workers (int):
+            Number of workers to precompute features. Defaults to 0.
+
         use_noise_augment (bool):
             Augment the input audio with random noise.

+        start_by_longest (bool):
+            If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
+            Defaults to False.
+
         add_blank (bool):
             Add blank characters between each other two characters. It improves performance for some models at expense
             of slower run-time due to the longer input sequence.

@@ -183,12 +213,19 @@ class BaseTTSConfig(BaseTrainingConfig):

         test_sentences (List[str]):
             List of sentences to be used at testing. Defaults to '[]'

+        eval_split_max_size (int):
+            Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
+
+        eval_split_size (float):
+            If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
+            If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
     """

     audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
     # phoneme settings
     use_phonemes: bool = False
-    use_espeak_phonemes: bool = True
+    phonemizer: str = None
     phoneme_language: str = None
     compute_input_seq_cache: bool = False
     text_cleaner: str = None

@@ -197,17 +234,21 @@ class BaseTTSConfig(BaseTrainingConfig):
     phoneme_cache_path: str = None
     # vocabulary parameters
     characters: CharactersConfig = None
+    add_blank: bool = False
     # training params
     batch_group_size: int = 0
     loss_masking: bool = None
     # dataloading
     sort_by_audio_len: bool = False
-    min_seq_len: int = 1
-    max_seq_len: int = float("inf")
+    min_audio_len: int = 1
+    max_audio_len: int = float("inf")
+    min_text_len: int = 1
+    max_text_len: int = float("inf")
     compute_f0: bool = False
     compute_linear_spec: bool = False
+    precompute_num_workers: int = 0
     use_noise_augment: bool = False
-    add_blank: bool = False
+    start_by_longest: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # optimizer

@@ -218,3 +259,6 @@ class BaseTTSConfig(BaseTrainingConfig):
     lr_scheduler_params: dict = field(default_factory=lambda: {})
     # testing
     test_sentences: List[str] = field(default_factory=lambda: [])
+    # evaluation
+    eval_split_max_size: int = None
+    eval_split_size: float = 0.01

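With the reworked `CharactersConfig`, a config either describes a `BaseCharacters` set field by field or hands a prebuilt `vocab_dict` to `BaseVocabulary`. An illustrative instantiation follows; the character strings and special tokens are made up, only the field names come from this diff, and the import path is an assumption.

from TTS.tts.configs.shared_configs import CharactersConfig  # import path assumed

chars = CharactersConfig(
    characters_class=None,        # None: pick Phonemes or Graphemes from the config
    vocab_dict=None,              # only set when driving BaseVocabulary directly
    pad="<PAD>",
    eos="<EOS>",
    bos="<BOS>",
    blank="<BLNK>",
    characters="abcdefghijklmnopqrstuvwxyz",
    punctuations="!'(),-.:;? ",
    is_unique=True,               # dedupe legacy character lists
    is_sorted=True,               # keep the set in alphabetical order
)
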
@@ -89,8 +89,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
         binary_loss_alpha (float):
             Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.

-        binary_align_loss_start_step (int):
-            Start binary alignment loss after this many steps. Defaults to 20000.
+        binary_loss_warmup_epochs (float):
+            Number of epochs to gradually increase the binary loss impact. Defaults to 150.

         min_seq_len (int):
             Minimum input sequence length to be used at training.

@@ -150,7 +150,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
     spec_loss_alpha: float = 1.0
     aligner_loss_alpha: float = 1.0
     binary_align_loss_alpha: float = 0.3
-    binary_align_loss_start_step: int = 50000
+    binary_loss_warmup_epochs: int = 150

     # overrides
     min_seq_len: int = 13

@@ -83,6 +83,8 @@ class TacotronConfig(BaseTTSConfig):
         ddc_r (int):
             reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this
             as a multiple of the `r` value. Defaults to 6.
+
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.

         use_speaker_embedding (bool):
             enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
             in the multi-speaker mode. Defaults to False.

@@ -176,6 +178,7 @@ class TacotronConfig(BaseTTSConfig):
     ddc_r: int = 6

     # multi-speaker settings
+    speakers_file: str = None
     use_speaker_embedding: bool = False
     speaker_embedding_dim: int = 512
     use_d_vector_file: bool = False

@@ -17,7 +17,7 @@ class VitsConfig(BaseTTSConfig):
             Model architecture arguments. Defaults to `VitsArgs()`.

         grad_clip (List):
-            Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`.
+            Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

         lr_gen (float):
             Initial learning rate for the generator. Defaults to 0.0002.

@@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig):
         compute_linear_spec (bool):
             If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

-        sort_by_audio_len (bool):
-            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
-
-        min_seq_len (int):
-            Minimum sequnce length to be considered for training. Defaults to `0`.
-
-        max_seq_len (int):
-            Maximum sequnce length to be considered for training. Defaults to `500000`.
-
         r (int):
             Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

@@ -130,9 +121,6 @@ class VitsConfig(BaseTTSConfig):
     compute_linear_spec: bool = True

     # overrides
-    sort_by_audio_len: bool = True
-    min_seq_len: int = 0
-    max_seq_len: int = 500000
     r: int = 1  # DO NOT CHANGE
     add_blank: bool = True

@@ -9,25 +9,48 @@ from TTS.tts.datasets.dataset import *
 from TTS.tts.datasets.formatters import *


-def split_dataset(items):
+def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
     """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.

     Args:
-        items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
+        items (List[List]):
+            A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
+
+        eval_split_max_size (int):
+            Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
+
+        eval_split_size (float):
+            If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
+            If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
     """
-    speakers = [item[-1] for item in items]
+    speakers = [item["speaker_name"] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = min(500, int(len(items) * 0.01))
-    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
+    if eval_split_size > 1:
+        eval_split_size = int(eval_split_size)
+    else:
+        if eval_split_max_size:
+            eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size))
+        else:
+            eval_split_size = int(len(items) * eval_split_size)
+
+    assert (
+        eval_split_size > 0
+    ), " [!] You do not have enough samples for the evaluation set. You can work around this by setting the 'eval_split_size' parameter to a minimum of {}".format(
+        1 / len(items)
+    )
     np.random.seed(0)
     np.random.shuffle(items)
     if is_multi_speaker:
         items_eval = []
-        speakers = [item[-1] for item in items]
+        speakers = [item["speaker_name"] for item in items]
         speaker_counter = Counter(speakers)
         while len(items_eval) < eval_split_size:
             item_idx = np.random.randint(0, len(items))
-            speaker_to_be_removed = items[item_idx][-1]
+            speaker_to_be_removed = items[item_idx]["speaker_name"]
             if speaker_counter[speaker_to_be_removed] > 1:
                 items_eval.append(items[item_idx])
                 speaker_counter[speaker_to_be_removed] -= 1

@@ -37,7 +60,11 @@ def split_dataset(items):


 def load_tts_samples(
-    datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None
+    datasets: Union[List[Dict], Dict],
+    eval_split=True,
+    formatter: Callable = None,
+    eval_split_max_size=None,
+    eval_split_size=0.01,
 ) -> Tuple[List[List], List[List]]:
     """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
     If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based

@@ -52,9 +79,16 @@ def load_tts_samples(

         formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It
             must take the root_path and the meta_file name and return a list of samples in the format of
-            `[[audio_path, text, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as
+            `[[text, audio_path, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as
             example. Defaults to None.

+        eval_split_max_size (int):
+            Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
+
+        eval_split_size (float):
+            If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
+            If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
+
     Returns:
         Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
     """

@@ -75,28 +109,28 @@ def load_tts_samples(
             formatter = _get_formatter_by_name(name)
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
-        meta_data_train = [[*item, language] for item in meta_data_train]
+        meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]

         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [[*item, language] for item in meta_data_eval]
+                meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
             else:
-                meta_data_eval, meta_data_train = split_dataset(meta_data_train)
+                meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
             meta_data_eval_all += meta_data_eval
             meta_data_train_all += meta_data_train
         # load attention masks for the duration predictor training
         if dataset.meta_file_attn_mask:
             meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
-                attn_file = meta_data[ins[1]].strip()
-                meta_data_train_all[idx].append(attn_file)
+                attn_file = meta_data[ins["audio_file"]].strip()
+                meta_data_train_all[idx].update({"alignment_file": attn_file})
             if meta_data_eval_all:
                 for idx, ins in enumerate(meta_data_eval_all):
-                    attn_file = meta_data[ins[1]].strip()
-                    meta_data_eval_all[idx].append(attn_file)
+                    attn_file = meta_data[ins["audio_file"]].strip()
+                    meta_data_eval_all[idx].update({"alignment_file": attn_file})
         # set none for the next iter
         formatter = None
     return meta_data_train_all, meta_data_eval_all

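The size rule in `split_dataset` treats `eval_split_size` as a fraction when it is at most 1 and as an absolute sample count when it is greater than 1, with `eval_split_max_size` capping the fractional case. The rule isolated, with worked examples:

def resolve_eval_size(n_items: int, eval_split_size=0.01, eval_split_max_size=None) -> int:
    """Mirror of the split-size logic above, for illustration only."""
    if eval_split_size > 1:
        return int(eval_split_size)                                # absolute count
    if eval_split_max_size:
        return min(eval_split_max_size, int(n_items * eval_split_size))
    return int(n_items * eval_split_size)                          # plain proportion

assert resolve_eval_size(10_000) == 100                  # 1% of 10k samples
assert resolve_eval_size(100_000, 0.01, 256) == 256      # proportion capped at 256
assert resolve_eval_size(10_000, 500) == 500             # absolute count passthrough
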
@@ -1,8 +1,7 @@
 import collections
 import os
 import random
-from multiprocessing import Pool
-from typing import Dict, List
+from typing import Dict, List, Union

 import numpy as np
 import torch

@@ -10,87 +9,99 @@ import tqdm
 from torch.utils.data import Dataset

 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
-from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 from TTS.utils.audio import AudioProcessor

 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
 torch.multiprocessing.set_sharing_strategy("file_system")


+def _parse_sample(item):
+    language_name = None
+    attn_file = None
+    if len(item) == 5:
+        text, wav_file, speaker_name, language_name, attn_file = item
+    elif len(item) == 4:
+        text, wav_file, speaker_name, language_name = item
+    elif len(item) == 3:
+        text, wav_file, speaker_name = item
+    else:
+        raise ValueError(" [!] Dataset cannot parse the sample.")
+    return text, wav_file, speaker_name, language_name, attn_file
+
+
+def noise_augment_audio(wav):
+    return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
+
+
 class TTSDataset(Dataset):
     def __init__(
         self,
-        outputs_per_step: int,
-        text_cleaner: list,
-        compute_linear_spec: bool,
-        ap: AudioProcessor,
-        meta_data: List[List],
+        outputs_per_step: int = 1,
+        compute_linear_spec: bool = False,
+        ap: AudioProcessor = None,
+        samples: List[Dict] = None,
+        tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
         f0_cache_path: str = None,
-        characters: Dict = None,
-        custom_symbols: List = None,
-        add_blank: bool = False,
         return_wav: bool = False,
         batch_group_size: int = 0,
-        min_seq_len: int = 0,
-        max_seq_len: int = float("inf"),
-        use_phonemes: bool = False,
+        min_text_len: int = 0,
+        max_text_len: int = float("inf"),
+        min_audio_len: int = 0,
+        max_audio_len: int = float("inf"),
         phoneme_cache_path: str = None,
-        phoneme_language: str = "en-us",
-        enable_eos_bos: bool = False,
+        precompute_num_workers: int = 0,
         speaker_id_mapping: Dict = None,
         d_vector_mapping: Dict = None,
         language_id_mapping: Dict = None,
         use_noise_augment: bool = False,
+        start_by_longest: bool = False,
         verbose: bool = False,
     ):
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.

-        If you need something different, you can inherit and override.
+        If you need something different, you can subclass and override.

         Args:
             outputs_per_step (int): Number of time frames predicted per step.

-            text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
-
             compute_linear_spec (bool): compute linear spectrogram if True.

             ap (TTS.tts.utils.AudioProcessor): Audio processor object.

-            meta_data (list): List of dataset instances.
+            samples (list): List of dataset samples.
+
+            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
+                use the given. Defaults to None.

             compute_f0 (bool): compute f0 if True. Defaults to False.

             f0_cache_path (str): Path to store f0 cache. Defaults to None.

-            characters (dict): `dict` of custom text characters used for converting texts to sequences.
-
-            custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
-                set of symbols need to pass it here. Defaults to `None`.
-
-            add_blank (bool): Add a special `blank` character after every other character. It helps some
-                models achieve better results. Defaults to false.
-
             return_wav (bool): Return the waveform of the sample. Defaults to False.

             batch_group_size (int): Range of batch randomization after sorting
                 sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
                 batch. Set 0 to disable. Defaults to 0.

-            min_seq_len (int): Minimum input sequence length to be processed
-                by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
-                minimum input length due to its architecture. Defaults to 0.
+            min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
+                Defaults to 0.

-            max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
-                It helps for controlling the VRAM usage against long input sequences. Especially models with
-                RNN layers are sensitive to input length. Defaults to `Inf`.
+            max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
+                Defaults to float("inf").

-            use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
+            min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
+                Defaults to 0.
+
+            max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
+                The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
+                this value if you encounter an OOM error in training. Defaults to float("inf").

             phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
                 separate file. Defaults to None.

-            phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
-
-            enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults
-                to False.
+            precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.

             speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
                 embedding layer. Defaults to None.

@@ -99,285 +110,254 @@ class TTSDataset(Dataset):

             use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.

+            start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
+
             verbose (bool): Print diagnostic information. Defaults to false.
         """
         super().__init__()
         self.batch_group_size = batch_group_size
-        self.items = meta_data
+        self._samples = samples
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
-        self.cleaners = text_cleaner
         self.compute_linear_spec = compute_linear_spec
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
         self.f0_cache_path = f0_cache_path
-        self.min_seq_len = min_seq_len
-        self.max_seq_len = max_seq_len
+        self.min_audio_len = min_audio_len
+        self.max_audio_len = max_audio_len
+        self.min_text_len = min_text_len
+        self.max_text_len = max_text_len
         self.ap = ap
-        self.characters = characters
-        self.custom_symbols = custom_symbols
-        self.add_blank = add_blank
-        self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
-        self.phoneme_language = phoneme_language
-        self.enable_eos_bos = enable_eos_bos
         self.speaker_id_mapping = speaker_id_mapping
         self.d_vector_mapping = d_vector_mapping
         self.language_id_mapping = language_id_mapping
         self.use_noise_augment = use_noise_augment
+        self.start_by_longest = start_by_longest

         self.verbose = verbose
         self.input_seq_computed = False
         self.rescue_item_idx = 1
         self.pitch_computed = False
+        self.tokenizer = tokenizer

+        if self.tokenizer.use_phonemes:
+            self.phoneme_dataset = PhonemeDataset(
+                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
+            )

-        if use_phonemes and not os.path.isdir(phoneme_cache_path):
-            os.makedirs(phoneme_cache_path, exist_ok=True)
         if compute_f0:
-            self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
+            self.f0_dataset = F0Dataset(
+                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
+            )

         if self.verbose:
-            print("\n > DataLoader initialization")
-            print(" | > Use phonemes: {}".format(self.use_phonemes))
-            if use_phonemes:
-                print(" | > phoneme language: {}".format(phoneme_language))
-            print(" | > Number of instances : {}".format(len(self.items)))
+            self.print_logs()

+    @property
+    def lengths(self):
+        lens = []
+        for item in self.samples:
+            _, wav_file, *_ = _parse_sample(item)
+            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
+            lens.append(audio_len)
+        return lens
+
+    @property
+    def samples(self):
+        return self._samples
+
+    @samples.setter
+    def samples(self, new_samples):
+        self._samples = new_samples
+        if hasattr(self, "f0_dataset"):
+            self.f0_dataset.samples = new_samples
+        if hasattr(self, "phoneme_dataset"):
+            self.phoneme_dataset.samples = new_samples
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        return self.load_data(idx)
+
+    def print_logs(self, level: int = 0) -> None:
+        indent = "\t" * level
+        print("\n")
+        print(f"{indent}> DataLoader initialization")
+        print(f"{indent}| > Tokenizer:")
+        self.tokenizer.print_logs(level + 1)
+        print(f"{indent}| > Number of instances : {len(self.samples)}")

     def load_wav(self, filename):
-        audio = self.ap.load_wav(filename)
-        return audio
+        waveform = self.ap.load_wav(filename)
+        assert waveform.size > 0
+        return waveform

+    def get_phonemes(self, idx, text):
+        out_dict = self.phoneme_dataset[idx]
+        assert text == out_dict["text"], f"{text} != {out_dict['text']}"
+        assert len(out_dict["token_ids"]) > 0
+        return out_dict
+
+    def get_f0(self, idx):
+        out_dict = self.f0_dataset[idx]
+        item = self.samples[idx]
+        assert item["audio_file"] == out_dict["audio_file"]
+        return out_dict
+
     @staticmethod
-    def load_np(filename):
-        data = np.load(filename).astype("float32")
-        return data
+    def get_attn_mask(attn_file):
+        return np.load(attn_file)

-    @staticmethod
-    def _generate_and_cache_phoneme_sequence(
-        text, cache_path, cleaners, language, custom_symbols, characters, add_blank
-    ):
-        """generate a phoneme sequence from text.
-        since the usage is for subsequent caching, we never add bos and
-        eos chars here. Instead we add those dynamically later; based on the
-        config option."""
-        phonemes = phoneme_to_sequence(
-            text,
-            [cleaners],
-            language=language,
-            enable_eos_bos=False,
-            custom_symbols=custom_symbols,
-            tp=characters,
-            add_blank=add_blank,
-        )
-        phonemes = np.asarray(phonemes, dtype=np.int32)
-        np.save(cache_path, phonemes)
-        return phonemes
-
-    @staticmethod
-    def _load_or_generate_phoneme_sequence(
-        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
-    ):
-        file_name = os.path.splitext(os.path.basename(wav_file))[0]
-
-        # different names for normal phonemes and with blank chars.
-        file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
-        cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
-        try:
-            phonemes = np.load(cache_path)
-        except FileNotFoundError:
-            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
-            )
-        except (ValueError, IOError):
-            print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
-            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
-            )
-        if enable_eos_bos:
-            phonemes = pad_with_eos_bos(phonemes, tp=characters)
-            phonemes = np.asarray(phonemes, dtype=np.int32)
-        return phonemes
+    def get_token_ids(self, idx, text):
+        if self.tokenizer.use_phonemes:
+            token_ids = self.get_phonemes(idx, text)["token_ids"]
+        else:
+            token_ids = self.tokenizer.text_to_ids(text)
+        return np.array(token_ids, dtype=np.int32)

     def load_data(self, idx):
-        item = self.items[idx]
+        item = self.samples[idx]

-        if len(item) == 5:
-            text, wav_file, speaker_name, language_name, attn_file = item
-        else:
-            text, wav_file, speaker_name, language_name = item
-            attn = None
-        raw_text = text
+        raw_text = item["text"]

-        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

         # apply noise for augmentation
         if self.use_noise_augment:
-            wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
+            wav = noise_augment_audio(wav)

-        if not self.input_seq_computed:
-            if self.use_phonemes:
-                text = self._load_or_generate_phoneme_sequence(
-                    wav_file,
-                    text,
-                    self.phoneme_cache_path,
-                    self.enable_eos_bos,
-                    self.cleaners,
-                    language_name if language_name else self.phoneme_language,
-                    self.custom_symbols,
-                    self.characters,
-                    self.add_blank,
-                )
-            else:
-                text = np.asarray(
-                    text_to_sequence(
-                        text,
-                        [self.cleaners],
-                        custom_symbols=self.custom_symbols,
-                        tp=self.characters,
-                        add_blank=self.add_blank,
-                    ),
-                    dtype=np.int32,
-                )
+        # get token ids
+        token_ids = self.get_token_ids(idx, item["text"])

-        assert text.size > 0, self.items[idx][1]
-        assert wav.size > 0, self.items[idx][1]
+        # get pre-computed attention maps
+        attn = None
+        if "alignment_file" in item:
+            attn = self.get_attn_mask(item["alignment_file"])

-        if "attn_file" in locals():
-            attn = np.load(attn_file)
-
-        if len(text) > self.max_seq_len:
-            # return a different sample if the phonemized
-            # text is longer than the threshold
-            # TODO: find a better fix
+        # after phonemization the text length may change
+        # this is a shameful 🤭 hack to prevent longer phonemes
+        # TODO: find a better fix
+        if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
             self.rescue_item_idx += 1
             return self.load_data(self.rescue_item_idx)

-        pitch = None
+        # get f0 values
+        f0 = None
         if self.compute_f0:
-            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
-            pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
+            f0 = self.get_f0(idx)["f0"]

         sample = {
             "raw_text": raw_text,
-            "text": text,
+            "token_ids": token_ids,
             "wav": wav,
-            "pitch": pitch,
+            "pitch": f0,
             "attn": attn,
-            "item_idx": self.items[idx][1],
-            "speaker_name": speaker_name,
-            "language_name": language_name,
-            "wav_file_name": os.path.basename(wav_file),
+            "item_idx": item["audio_file"],
+            "speaker_name": item["speaker_name"],
+            "language_name": item["language"],
+            "wav_file_name": os.path.basename(item["audio_file"]),
         }
         return sample

     @staticmethod
-    def _phoneme_worker(args):
-        item = args[0]
-        func_args = args[1]
-        text, wav_file, *_ = item
-        func_args[3] = (
-            item[3] if item[3] else func_args[3]
-        )  # override phoneme language if specified by the dataset formatter
-        phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
-        return phonemes
-
-    def compute_input_seq(self, num_workers=0):
-        """Compute the input sequences with multi-processing.
-        Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
-        if not self.use_phonemes:
-            if self.verbose:
-                print(" | > Computing input sequences ...")
-            for idx, item in enumerate(tqdm.tqdm(self.items)):
-                text, *_ = item
-                sequence = np.asarray(
-                    text_to_sequence(
-                        text,
-                        [self.cleaners],
-                        custom_symbols=self.custom_symbols,
-                        tp=self.characters,
-                        add_blank=self.add_blank,
-                    ),
-                    dtype=np.int32,
-                )
-                self.items[idx][0] = sequence
-
-        else:
-            func_args = [
-                self.phoneme_cache_path,
-                self.enable_eos_bos,
-                self.cleaners,
-                self.phoneme_language,
-                self.custom_symbols,
-                self.characters,
-                self.add_blank,
-            ]
-            if self.verbose:
-                print(" | > Computing phonemes ...")
-            if num_workers == 0:
-                for idx, item in enumerate(tqdm.tqdm(self.items)):
-                    phonemes = self._phoneme_worker([item, func_args])
-                    self.items[idx][0] = phonemes
-            else:
-                with Pool(num_workers) as p:
-                    phonemes = list(
-                        tqdm.tqdm(
-                            p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
-                            total=len(self.items),
-                        )
-                    )
-                    for idx, p in enumerate(phonemes):
-                        self.items[idx][0] = p
-
-    def sort_and_filter_items(self, by_audio_len=False):
-        r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out of the length
-        range.
-
-        Args:
-            by_audio_len (bool): if True, sort by audio length else by text length.
-        """
-        # compute the target sequence length
-        if by_audio_len:
-            lengths = []
-            for item in self.items:
-                lengths.append(os.path.getsize(item[1]) / 16 * 8)  # assuming 16bit audio
-            lengths = np.array(lengths)
-        else:
-            lengths = np.array([len(ins[0]) for ins in self.items])
-
-        idxs = np.argsort(lengths)
-        new_items = []
-        ignored = []
-        for i, idx in enumerate(idxs):
-            length = lengths[idx]
-            if length < self.min_seq_len or length > self.max_seq_len:
-                ignored.append(idx)
-            else:
-                new_items.append(self.items[idx])
-        # shuffle batch groups
-        if self.batch_group_size > 0:
-            for i in range(len(new_items) // self.batch_group_size):
-                offset = i * self.batch_group_size
-                end_offset = offset + self.batch_group_size
-                temp_items = new_items[offset:end_offset]
-                random.shuffle(temp_items)
-                new_items[offset:end_offset] = temp_items
-        self.items = new_items
-
-        if self.verbose:
-            print(" | > Max length sequence: {}".format(np.max(lengths)))
-            print(" | > Min length sequence: {}".format(np.min(lengths)))
-            print(" | > Avg length sequence: {}".format(np.mean(lengths)))
-            print(
-                " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
-                    self.max_seq_len, self.min_seq_len, len(ignored)
-                )
-            )
-            print(" | > Batch group size: {}.".format(self.batch_group_size))
-
-    def __len__(self):
-        return len(self.items)
-
-    def __getitem__(self, idx):
-        return self.load_data(idx)
+    def _compute_lengths(samples):
+        new_samples = []
+        for item in samples:
+            audio_length = os.path.getsize(item["audio_file"]) / 16 * 8  # assuming 16bit audio
+            text_length = len(item["text"])
+            item["audio_length"] = audio_length
+            item["text_length"] = text_length
+            new_samples += [item]
+        return new_samples
+
+    @staticmethod
+    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
+        idxs = np.argsort(lengths)  # ascending order
+        ignore_idx = []
+        keep_idx = []
+        for idx in idxs:
+            length = lengths[idx]
+            if length < min_len or length > max_len:
+                ignore_idx.append(idx)
+            else:
+                keep_idx.append(idx)
+        return ignore_idx, keep_idx
+
+    @staticmethod
+    def sort_by_length(samples: List[List]):
+        audio_lengths = [s["audio_length"] for s in samples]
+        idxs = np.argsort(audio_lengths)  # ascending order
+        return idxs
+
+    @staticmethod
+    def create_buckets(samples, batch_group_size: int):
+        assert batch_group_size > 0
+        for i in range(len(samples) // batch_group_size):
+            offset = i * batch_group_size
+            end_offset = offset + batch_group_size
+            temp_items = samples[offset:end_offset]
+            random.shuffle(temp_items)
+            samples[offset:end_offset] = temp_items
+        return samples
+
+    @staticmethod
+    def _select_samples_by_idx(idxs, samples):
+        samples_new = []
+        for idx in idxs:
+            samples_new.append(samples[idx])
+        return samples_new
+
+    def preprocess_samples(self):
+        r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out of the length
+        range.
+        """
+        samples = self._compute_lengths(self.samples)
+
+        # sort items based on the sequence length in ascending order
+        text_lengths = [i["text_length"] for i in samples]
+        audio_lengths = [i["audio_length"] for i in samples]
+        text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
+        audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
+        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
+        ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
+
+        samples = self._select_samples_by_idx(keep_idx, samples)
+
+        sorted_idxs = self.sort_by_length(samples)
+
+        if self.start_by_longest:
+            longest_idxs = sorted_idxs[-1]
+            sorted_idxs[-1] = sorted_idxs[0]
+            sorted_idxs[0] = longest_idxs
+
+        samples = self._select_samples_by_idx(sorted_idxs, samples)
+
+        if len(samples) == 0:
+            raise RuntimeError(" [!] No samples left")
+
+        # shuffle batch groups
+        # create batches with similar length items
+        # the larger the `batch_group_size`, the higher the length variety in a batch.
+        if self.batch_group_size > 0:
+            samples = self.create_buckets(samples, self.batch_group_size)
+
+        # update items to the new sorted items
+        audio_lengths = [s["audio_length"] for s in samples]
+        text_lengths = [s["text_length"] for s in samples]
+        self.samples = samples
+
+        if self.verbose:
+            print(" | > Preprocessing samples")
+            print(" | > Max text length: {}".format(np.max(text_lengths)))
+            print(" | > Min text length: {}".format(np.min(text_lengths)))
+            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
+            print(" | ")
+            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
+            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
+            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
+            print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
+            print(" | > Batch group size: {}.".format(self.batch_group_size))

     @staticmethod
     def _sort_batch(batch, text_lengths):
         """Sort the batch by the input text length for RNN efficiency.

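`preprocess_samples` above sorts by audio length, optionally swaps the longest sample to the front (`start_by_longest`, so an OOM surfaces as early as possible), and `create_buckets` then shuffles within fixed-size groups so each batch draws similar-length samples without the order being fully deterministic. The bucketing effect in isolation, on toy data:

import random

def create_buckets(samples, batch_group_size):
    # shuffle inside each group of `batch_group_size` length-sorted samples
    for i in range(len(samples) // batch_group_size):
        offset = i * batch_group_size
        end_offset = offset + batch_group_size
        group = samples[offset:end_offset]
        random.shuffle(group)
        samples[offset:end_offset] = group
    return samples

lengths = list(range(1, 17))              # stand-in for length-sorted samples
shuffled = create_buckets(lengths, 4)     # each chunk of 4 stays length-similar
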
@@ -402,10 +382,10 @@ class TTSDataset(Dataset):
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):

-            text_lengths = np.array([len(d["text"]) for d in batch])
+            token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])

             # sort items with text input length for RNN efficiency
-            batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
+            batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths)

             # convert list of dicts to dict of lists
             batch = {k: [dic[k] for dic in batch] for k in batch[0]}

@@ -447,7 +427,7 @@ class TTSDataset(Dataset):
             stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

             # PAD sequences with longest instance in the batch
-            text = prepare_data(batch["text"]).astype(np.int32)
+            token_ids = prepare_data(batch["token_ids"]).astype(np.int32)

             # PAD features with longest instance
             mel = prepare_tensor(mel, self.outputs_per_step)

@@ -456,12 +436,13 @@ class TTSDataset(Dataset):
             mel = mel.transpose(0, 2, 1)

             # convert things to pytorch
-            text_lengths = torch.LongTensor(text_lengths)
-            text = torch.LongTensor(text)
+            token_ids_lengths = torch.LongTensor(token_ids_lengths)
+            token_ids = torch.LongTensor(token_ids)
             mel = torch.FloatTensor(mel).contiguous()
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

+            # speaker vectors
             if d_vectors is not None:
                 d_vectors = torch.FloatTensor(d_vectors)

@@ -472,14 +453,13 @@ class TTSDataset(Dataset):
                 language_ids = torch.LongTensor(language_ids)

             # compute linear spectrogram
+            linear = None
             if self.compute_linear_spec:
                 linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
                 linear = torch.FloatTensor(linear).contiguous()
-            else:
-                linear = None

             # format waveforms
             wav_padded = None

@@ -495,8 +475,7 @@ class TTSDataset(Dataset):
                     wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                 wav_padded.transpose_(1, 2)

-            # compute f0
-            # TODO: compare perf in collate_fn vs in load_data
+            # format F0
             if self.compute_f0:
                 pitch = prepare_data(batch["pitch"])
                 assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"

@@ -504,23 +483,22 @@ class TTSDataset(Dataset):
             else:
                 pitch = None

-            # collate attention alignments
+            # format attention masks
+            attns = None
             if batch["attn"][0] is not None:
                 attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
                 for idx, attn in enumerate(attns):
                     pad2 = mel.shape[1] - attn.shape[1]
-                    pad1 = text.shape[1] - attn.shape[0]
+                    pad1 = token_ids.shape[1] - attn.shape[0]
                     assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                     attn = np.pad(attn, [[0, pad1], [0, pad2]])
                     attns[idx] = attn
                 attns = prepare_tensor(attns, self.outputs_per_step)
                 attns = torch.FloatTensor(attns).unsqueeze(1)
-            else:
-                attns = None
+            # TODO: return dictionary

             return {
-                "text": text,
-                "text_lengths": text_lengths,
+                "token_id": token_ids,
+                "token_id_lengths": token_ids_lengths,
                 "speaker_names": batch["speaker_name"],
                 "linear": linear,
                 "mel": mel,

@ -546,22 +524,185 @@ class TTSDataset(Dataset):
|
|||
)
|
||||
|
||||
|
||||
class PitchExtractor:
|
||||
"""Pitch Extractor for computing F0 from wav files.
|
||||
class PhonemeDataset(Dataset):
|
||||
"""Phoneme Dataset for converting input text to phonemes and then token IDs
|
||||
|
||||
At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
|
||||
loading latency. If `cache_path` is already present, it skips the pre-computation.
|
||||
|
||||
Args:
|
||||
items (List[List]): Dataset samples.
|
||||
verbose (bool): Whether to print the progress.
|
||||
samples (Union[List[List], List[Dict]]):
|
||||
List of samples. Each sample is a list or a dict.
|
||||
|
||||
tokenizer (TTSTokenizer):
|
||||
Tokenizer to convert input text to phonemes.
|
||||
|
||||
cache_path (str):
|
||||
Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
|
||||
|
||||
precompute_num_workers (int):
|
||||
Number of workers used for pre-computing the phonemes. Defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
items: List[List],
|
||||
verbose=False,
|
||||
samples: Union[List[Dict], List[List]],
|
||||
tokenizer: "TTSTokenizer",
|
||||
cache_path: str,
|
||||
precompute_num_workers=0,
|
||||
):
|
||||
self.items = items
|
||||
self.samples = samples
|
||||
self.tokenizer = tokenizer
|
||||
self.cache_path = cache_path
|
||||
if cache_path is not None and not os.path.exists(cache_path):
|
||||
os.makedirs(cache_path)
|
||||
self.precompute(precompute_num_workers)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.samples[index]
|
||||
ids = self.compute_or_load(item["audio_file"], item["text"])
|
||||
ph_hat = self.tokenizer.ids_to_text(ids)
|
||||
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def compute_or_load(self, wav_file, text):
|
||||
"""Compute phonemes for the given text.
|
||||
|
||||
If the phonemes are already cached, load them from cache.
|
||||
"""
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
file_ext = "_phoneme.npy"
|
||||
cache_path = os.path.join(self.cache_path, file_name + file_ext)
|
||||
try:
|
||||
ids = np.load(cache_path)
|
||||
except FileNotFoundError:
|
||||
ids = self.tokenizer.text_to_ids(text)
|
||||
np.save(cache_path, ids)
|
||||
return ids
|
||||
|
||||
def get_pad_id(self):
|
||||
"""Get pad token ID for sequence padding"""
|
||||
return self.tokenizer.pad_id
|
||||
|
||||
def precompute(self, num_workers=1):
|
||||
"""Precompute phonemes for all samples.
|
||||
|
||||
We use pytorch dataloader because we are lazy.
|
||||
"""
|
||||
print("[*] Pre-computing phonemes...")
|
||||
with tqdm.tqdm(total=len(self)) as pbar:
|
||||
batch_size = num_workers if num_workers > 0 else 1
|
||||
dataloder = torch.utils.data.DataLoader(
|
||||
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
|
||||
)
|
||||
for _ in dataloder:
|
||||
pbar.update(batch_size)
|
||||
|
||||
def collate_fn(self, batch):
|
||||
ids = [item["token_ids"] for item in batch]
|
||||
ids_lens = [item["token_ids_len"] for item in batch]
|
||||
texts = [item["text"] for item in batch]
|
||||
texts_hat = [item["ph_hat"] for item in batch]
|
||||
ids_lens_max = max(ids_lens)
|
||||
ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
|
||||
for i, ids_len in enumerate(ids_lens):
|
||||
ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
|
||||
return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}
|
||||
|
||||
def print_logs(self, level: int = 0) -> None:
|
||||
indent = "\t" * level
|
||||
print("\n")
|
||||
print(f"{indent}> PhonemeDataset ")
|
||||
print(f"{indent}| > Tokenizer:")
|
||||
self.tokenizer.print_logs(level + 1)
|
||||
print(f"{indent}| > Number of instances : {len(self.samples)}")
|
||||
|
||||
|
||||
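A minimal usage sketch (not part of this commit) for the new PhonemeDataset, assuming the post-change signature `(samples, tokenizer, cache_path, precompute_num_workers)`; the sample paths, cache directory, and `config` object are assumptions for illustration, following the `audio_file`/`text` keys used above.

import torch
from TTS.tts.datasets.dataset import PhonemeDataset
from TTS.tts.utils.text.tokenizer import TTSTokenizer

samples = [{"audio_file": "wavs/utt1.wav", "text": "Hello world."}]  # assumed sample layout
tokenizer, _ = TTSTokenizer.init_from_config(config)  # `config` is assumed to exist
dataset = PhonemeDataset(samples, tokenizer=tokenizer, cache_path="phoneme_cache", precompute_num_workers=2)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)
for batch in loader:
    print(batch["token_ids"].shape)  # [batch, max_token_len], padded with the tokenizer's pad id
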
class F0Dataset:
    """F0 Dataset for computing F0 from wav files on CPU.

    Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
    also computes the mean and std of F0 values if `normalize_f0` is True.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        ap (AudioProcessor):
            AudioProcessor to compute F0 from wav files.

        cache_path (str):
            Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
            Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the F0 values. Defaults to 0.

        normalize_f0 (bool):
            Whether to normalize F0 values by mean and std. Defaults to True.
    """

    def __init__(
        self,
        samples: Union[List[List], List[Dict]],
        ap: "AudioProcessor",
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_f0=True,
    ):
        self.samples = samples
        self.ap = ap
        self.verbose = verbose
        self.cache_path = cache_path
        self.normalize_f0 = normalize_f0
        self.pad_id = 0.0
        self.mean = None
        self.std = None
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)
        if normalize_f0:
            self.load_stats(cache_path)

    def __getitem__(self, idx):
        item = self.samples[idx]
        f0 = self.compute_or_load(item["audio_file"])
        if self.normalize_f0:
            assert self.mean is not None and self.std is not None, " [!] Mean and STD are not available"
            f0 = self.normalize(f0)
        return {"audio_file": item["audio_file"], "f0": f0}

    def __len__(self):
        return len(self.samples)

    def precompute(self, num_workers=0):
        print("[*] Pre-computing F0s...")
        with tqdm.tqdm(total=len(self)) as pbar:
            batch_size = num_workers if num_workers > 0 else 1
            # we do not normalize at preprocessing
            normalize_f0 = self.normalize_f0
            self.normalize_f0 = False
            dataloader = torch.utils.data.DataLoader(
                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
            )
            computed_data = []
            for batch in dataloader:
                f0 = batch["f0"]
                computed_data.append([f for f in f0])
                pbar.update(batch_size)
            self.normalize_f0 = normalize_f0

        if self.normalize_f0:
            computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
            pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
            pitch_stats = {"mean": pitch_mean, "std": pitch_std}
            np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)

    def get_pad_id(self):
        return self.pad_id

    @staticmethod
    def create_pitch_file_path(wav_file, cache_path):

@@ -583,70 +724,49 @@ class PitchExtractor:
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std

    def normalize_pitch(self, pitch):
    def load_stats(self, cache_path):
        stats_path = os.path.join(cache_path, "pitch_stats.npy")
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
        self.std = stats["std"].astype(np.float32)

    def normalize(self, pitch):
        zero_idxs = np.where(pitch == 0.0)[0]
        pitch = pitch - self.mean
        pitch = pitch / self.std
        pitch[zero_idxs] = 0.0
        return pitch

    def denormalize_pitch(self, pitch):
    def denormalize(self, pitch):
        zero_idxs = np.where(pitch == 0.0)[0]
        pitch *= self.std
        pitch += self.mean
        pitch[zero_idxs] = 0.0
        return pitch

    @staticmethod
    def load_or_compute_pitch(ap, wav_file, cache_path):
    def compute_or_load(self, wav_file):
        """Compute pitch and return a numpy array of pitch values."""
        pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
        if not os.path.exists(pitch_file):
            pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)

    @staticmethod
    def _pitch_worker(args):
        item = args[0]
        ap = args[1]
        cache_path = args[2]
        _, wav_file, *_ = item
        pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
        if not os.path.exists(pitch_file):
            pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
            return pitch
        return None

    def collate_fn(self, batch):
        audio_file = [item["audio_file"] for item in batch]
        f0s = [item["f0"] for item in batch]
        f0_lens = [len(item["f0"]) for item in batch]
        f0_lens_max = max(f0_lens)
        # F0 values are floats, so pad them into a FloatTensor.
        f0s_torch = torch.FloatTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
        for i, f0_len in enumerate(f0_lens):
            f0s_torch[i, :f0_len] = torch.FloatTensor(f0s[i])
        return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}

    def compute_pitch(self, ap, cache_path, num_workers=0):
        """Compute the input sequences with multi-processing.
        Call it before passing the dataset to the data loader to cache the input sequences for faster data loading."""
        if not os.path.exists(cache_path):
            os.makedirs(cache_path, exist_ok=True)

        if self.verbose:
            print(" | > Computing pitch features ...")
        if num_workers == 0:
            pitch_vecs = []
            for _, item in enumerate(tqdm.tqdm(self.items)):
                pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
        else:
            with Pool(num_workers) as p:
                pitch_vecs = list(
                    tqdm.tqdm(
                        p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
                        total=len(self.items),
                    )
                )
        pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
        pitch_stats = {"mean": pitch_mean, "std": pitch_std}
        np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)

    def load_pitch_stats(self, cache_path):
        stats_path = os.path.join(cache_path, "pitch_stats.npy")
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
        self.std = stats["std"].astype(np.float32)

    def print_logs(self, level: int = 0) -> None:
        indent = "\t" * level
        print("\n")
        print(f"{indent}> F0Dataset ")
        print(f"{indent}| > Number of instances : {len(self.samples)}")
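A stand-alone numpy sketch (not from the commit) of the zero-preserving F0 normalization above: unvoiced frames (f0 == 0) are excluded from the statistics and forced back to zero after normalization, so voicing decisions survive the round trip. The values are made up for illustration.

import numpy as np

f0 = np.array([0.0, 110.0, 220.0, 0.0, 180.0], dtype=np.float32)
nonzeros = f0[f0 != 0.0]
mean, std = np.mean(nonzeros), np.std(nonzeros)

norm = (f0 - mean) / std
norm[f0 == 0.0] = 0.0          # keep unvoiced frames at exactly zero

denorm = norm * std + mean
denorm[norm == 0.0] = 0.0      # the round trip restores the original voiced values
assert np.allclose(denorm[f0 != 0.0], f0[f0 != 0.0])
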
@@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            cols = line.split("\t")
            wav_file = os.path.join(root_path, cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

@@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            wav_file = cols[1].strip()
            text = cols[0].strip()
            wav_file = os.path.join(root_path, "wavs", wav_file)
            items.append([text, wav_file, speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

@@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            text = cols[1].strip()
            folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
            wav_file = os.path.join(root_path, folder_name, wav_file)
            items.append([text, wav_file, speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

@@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                if os.path.isfile(wav_file):
                    text = cols[1].strip()
                    items.append([text, wav_file, speaker_name])
                    items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
                else:
                    # M-AI-Labs have some missing samples, so just print a warning
                    print("> File %s does not exist!" % (wav_file))

@@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[2]
            items.append([text, wav_file, speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

@@ -129,11 +129,15 @@ def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, "r", encoding="utf-8") as ttf:
        speaker_id = 0
        for idx, line in enumerate(ttf):
            # 2 samples per speaker to avoid eval split issues
            if idx % 2 == 0:
                speaker_id += 1
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[2]
            items.append([text, wav_file, f"ljspeech-{idx}"])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
    return items

@@ -150,7 +154,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
        if not os.path.exists(wav_file):
            print(f" [!] {wav_file} in metafile does not exist. Skipping...")
            continue
        items.append([text, wav_file, speaker_name])
        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

@@ -165,7 +169,7 @@ def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            cols = line.split("|")
            wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

@@ -179,7 +183,7 @@ def css10(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            cols = line.split("|")
            wav_file = os.path.join(root_path, cols[0])
            text = cols[1]
            items.append([text, wav_file, speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

@@ -193,7 +197,7 @@ def nancy(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            utt_id = line.split()[1]
            text = line[line.find('"') + 1 : line.rfind('"') - 1]
            wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
            items.append([text, wav_file, speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

@@ -213,7 +217,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
                if speaker_name in ignored_speakers:
                    continue
            wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
            items.append([text, wav_file, "MCV_" + speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
    return items

@@ -240,7 +244,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
                if isinstance(ignored_speakers, list):
                    if speaker_name in ignored_speakers:
                        continue
                items.append([text, wav_file, "LTTS_" + speaker_name])
                items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
    for item in items:
        assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
    return items

@@ -259,7 +263,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            skipped_files.append(wav_file)
            continue
        text = cols[1].strip()
        items.append([text, wav_file, speaker_name])
        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
    return items

@@ -281,12 +285,32 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
            if isinstance(ignored_speakers, list):
                if speaker_id in ignored_speakers:
                    continue
            items.append([text, wav_file, speaker_id])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
    return items


def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None):
    """VCTK dataset v0.92.

    URL:
        https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip

    This dataset has 2 recordings per speaker, annotated with ```mic1``` and ```mic2```.
    It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset.

    mic1:
        Audio recorded using an omni-directional microphone (DPA 4035).
        Contains very low frequency noises.
        This is the same audio released in previous versions of VCTK:
        https://doi.org/10.7488/ds/1994

    mic2:
        Audio recorded using a small diaphragm condenser microphone with
        very wide bandwidth (Sennheiser MKH 800).
        Two speakers, p280 and p315, had technical issues with the audio
        recordings using the MKH 800.
    """
    file_ext = "flac"
    items = []
    meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for meta_file in meta_files:

@@ -298,26 +322,33 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
                continue
        with open(meta_file, "r", encoding="utf-8") as file_text:
            text = file_text.readlines()[0]
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
        items.append([text, wav_file, "VCTK_" + speaker_id])

        # p280 has no mic2 recordings
        if speaker_id == "p280":
            wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}")
        else:
            wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
        if os.path.exists(wav_file):
            items.append([text, wav_file, "VCTK_" + speaker_id])
        else:
            print(f" [!] wav files don't exist - {wav_file}")
    return items


def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):  # pylint: disable=unused-argument
def vctk_old(root_path, meta_files=None, wavs_path="wav48"):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    test_speakers = meta_files
    items = []
    txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for text_file in txt_files:
        _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
    meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for meta_file in meta_files:
        _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
        file_id = txt_file.split(".")[0]
        # ignore speakers
        if isinstance(ignored_speakers, list):
            if speaker_id in ignored_speakers:
        if isinstance(test_speakers, list):  # if it is a list, ignore these speaker ids
            if speaker_id in test_speakers:
                continue
        with open(meta_file, "r", encoding="utf-8") as file_text:
            text = file_text.readlines()[0]
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
        items.append([None, wav_file, "VCTK_" + speaker_id])

        items.append([text, wav_file, "VCTK_old_" + speaker_id])
    return items
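A hedged usage sketch (not from the commit) for the updated loader above; the dataset root path is an assumption:

# `mic` picks one of the two microphone tracks of VCTK 0.92.
items = vctk("/data/VCTK-Corpus-0.92", mic="mic1")
print(len(items), items[0])  # [text, wav_file, "VCTK_<speaker_id>"]
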
@@ -334,7 +365,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
            if isinstance(ignored_speakers, list):
                if speaker in ignored_speakers:
                    continue
            items.append([text, wav_file, "MLS_" + speaker])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
    return items

@@ -404,7 +435,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]:  # pylint: disable=unused-argument
        for line in ttf:
            wav_name, text = line.rstrip("\n").split("|")
            wav_path = os.path.join(root_path, "clips_22", wav_name)
            items.append([text, wav_path, speaker_name])
            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
    return items

@@ -418,5 +449,5 @@ def kokoro(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[2].replace(" ", "")
            items.append([text, wav_file, speaker_name])
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
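The recurring change across these formatter hunks is the switch from positional `[text, wav_file, speaker_name]` lists to keyed sample dicts. A hedged sketch of a custom formatter following the new convention; the "wav_id|text" metadata layout and paths are assumptions:

import os

def my_dataset(root_path, meta_file, **kwargs):  # hypothetical formatter
    items = []
    speaker_name = "my_speaker"
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.strip().split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            items.append({"text": cols[1], "audio_file": wav_file, "speaker_name": speaker_name})
    return items
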
@@ -113,7 +113,7 @@ class ActNorm(nn.Module):
        denom = torch.sum(x_mask, [0, 2])
        m = torch.sum(x * x_mask, [0, 2]) / denom
        m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
        v = m_sq - (m ** 2)
        v = m_sq - (m**2)
        logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

        bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
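A stand-alone sketch (not from the commit) of the masked-statistics computation above: mean and variance are taken only over unmasked time steps, via E[x^2] - E[x]^2. Shapes are assumptions.

import torch

x = torch.randn(4, 8, 50)             # [B, C, T]
x_mask = torch.ones(4, 1, 50)
x_mask[:, :, 40:] = 0                 # last 10 frames are padding

denom = torch.sum(x_mask, [0, 2])     # number of valid frames
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - m**2                       # masked variance, padding excluded
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
print(m.shape, logs.shape)            # per-channel statistics: [8], [8]
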
@@ -65,7 +65,7 @@ class WN(torch.nn.Module):
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
        # intermediate layers
        for i in range(num_layers):
            dilation = dilation_rate ** i
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
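A quick check (not from the commit) of the padding rule in the loop above: with padding = (k*d - d) / 2, each dilated layer preserves the sequence length for odd kernel sizes. Channel sizes are assumptions.

import torch

kernel_size, dilation_rate, num_layers = 3, 2, 4
x = torch.randn(1, 16, 100)
for i in range(num_layers):
    dilation = dilation_rate**i
    padding = int((kernel_size * dilation - dilation) / 2)
    x = torch.nn.Conv1d(16, 16, kernel_size, dilation=dilation, padding=padding)(x)
assert x.shape[-1] == 100  # length preserved at every layer
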
@@ -101,7 +101,7 @@ class Encoder(nn.Module):
        self.encoder_type = encoder_type
        # embedding layer
        self.emb = nn.Embedding(num_chars, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        # init encoder module
        if encoder_type.lower() == "rel_pos_transformer":
            if use_prenet:

@@ -88,7 +88,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
        # relative positional encoding layers
        if rel_attn_window_size is not None:
            n_heads_rel = 1 if heads_share else num_heads
            rel_stddev = self.k_channels ** -0.5
            rel_stddev = self.k_channels**-0.5
            emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
            )

@@ -235,7 +235,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
        batch, heads, length, _ = x.size()
        # pad along the column dimension
        x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
        x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
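A toy shape check (not from the commit) of the pad-and-reshape "skew" above, which converts absolute-position scores [B, H, L, L] into relative-position form [B, H, L, 2L-1] without an explicit gather. Sizes are assumptions.

import torch
import torch.nn.functional as F

batch, heads, length = 2, 4, 5
x = torch.randn(batch, heads, length, length)        # absolute-position scores
x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])      # -> [B, H, L, 2L-1]
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])      # leading zeros skew the rows
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
assert x_final.shape == (batch, heads, length, 2 * length - 1)
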
@@ -218,7 +218,7 @@ class GuidedAttentionLoss(torch.nn.Module):
    def _make_ga_mask(ilen, olen, sigma):
        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
        grid_x, grid_y = grid_x.float(), grid_y.float()
        return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
        return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))

    @staticmethod
    def _make_masks(ilens, olens):
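A minimal rendering (not from the commit) of the guided attention mask built above: positions where text and audio progress together cost ~0, positions far off the diagonal cost ~1. Lengths and sigma are assumptions.

import torch

ilen, olen, sigma = 5, 8, 0.4
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
ga_mask = 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * sigma**2))
print(ga_mask)  # [olen, ilen]; near-zero along the diagonal
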
@@ -553,7 +553,6 @@ class VitsGeneratorLoss(nn.Module):
                rl = rl.float().detach()
                gl = gl.float()
                loss += torch.mean(torch.abs(rl - gl))

        return loss * 2

    @staticmethod

@@ -588,13 +587,12 @@ class VitsGeneratorLoss(nn.Module):

    @staticmethod
    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
        l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
        return l
        return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()

    def forward(
        self,
        waveform,
        waveform_hat,
        mel_slice,
        mel_slice_hat,
        z_p,
        logs_q,
        m_p,
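The speaker-consistency term above in isolation (a sketch, not from the commit): cosine similarity between ground-truth and synthesized speaker embeddings, negated so that higher similarity means lower loss. The embedding size is an assumption.

import torch

gt_spk_emb = torch.randn(8, 256)
syn_spk_emb = gt_spk_emb.clone()
loss = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
print(loss.item())  # -1.0 when the embeddings match exactly
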
|
@ -610,8 +608,8 @@ class VitsGeneratorLoss(nn.Module):
|
|||
):
|
||||
"""
|
||||
Shapes:
|
||||
- waveform : :math:`[B, 1, T]`
|
||||
- waveform_hat: :math:`[B, 1, T]`
|
||||
- mel_slice : :math:`[B, 1, T]`
|
||||
- mel_slice_hat: :math:`[B, 1, T]`
|
||||
- z_p: :math:`[B, C, T]`
|
||||
- logs_q: :math:`[B, C, T]`
|
||||
- m_p: :math:`[B, C, T]`
|
||||
|
@ -624,23 +622,23 @@ class VitsGeneratorLoss(nn.Module):
|
|||
loss = 0.0
|
||||
return_dict = {}
|
||||
z_mask = sequence_mask(z_len).float()
|
||||
# compute mel spectrograms from the waveforms
|
||||
mel = self.stft(waveform)
|
||||
mel_hat = self.stft(waveform_hat)
|
||||
|
||||
# compute losses
|
||||
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
|
||||
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
||||
loss_kl = (
|
||||
self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
|
||||
* self.kl_loss_alpha
|
||||
)
|
||||
loss_feat = (
|
||||
self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
|
||||
)
|
||||
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
|
||||
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
||||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||
|
||||
if use_speaker_encoder_as_loss:
|
||||
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
|
||||
loss += loss_se
|
||||
loss = loss + loss_se
|
||||
return_dict["loss_spk_encoder"] = loss_se
|
||||
|
||||
# pass losses to the dict
|
||||
return_dict["loss_gen"] = loss_gen
|
||||
return_dict["loss_kl"] = loss_kl
|
||||
|
@ -665,20 +663,24 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
dr = dr.float()
|
||||
dg = dg.float()
|
||||
real_loss = torch.mean((1 - dr) ** 2)
|
||||
fake_loss = torch.mean(dg ** 2)
|
||||
fake_loss = torch.mean(dg**2)
|
||||
loss += real_loss + fake_loss
|
||||
real_losses.append(real_loss.item())
|
||||
fake_losses.append(fake_loss.item())
|
||||
|
||||
return loss, real_losses, fake_losses
|
||||
|
||||
def forward(self, scores_disc_real, scores_disc_fake):
|
||||
loss = 0.0
|
||||
return_dict = {}
|
||||
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
|
||||
loss_disc, loss_disc_real, _ = self.discriminator_loss(
|
||||
scores_real=scores_disc_real, scores_fake=scores_disc_fake
|
||||
)
|
||||
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
|
||||
loss = loss + return_dict["loss_disc"]
|
||||
return_dict["loss"] = loss
|
||||
|
||||
for i, ldr in enumerate(loss_disc_real):
|
||||
return_dict[f"loss_disc_real_{i}"] = ldr
|
||||
return return_dict
|
||||
|
||||
|
||||
|
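The LSGAN objective above, stand-alone (a sketch under assumed score shapes): real scores are pulled toward 1 and fake scores toward 0, summed over sub-discriminators.

import torch

scores_real = [torch.rand(4, 1) for _ in range(3)]  # one entry per sub-discriminator
scores_fake = [torch.rand(4, 1) for _ in range(3)]
loss = 0.0
for dr, dg in zip(scores_real, scores_fake):
    loss = loss + torch.mean((1 - dr) ** 2) + torch.mean(dg**2)
print(float(loss))
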
@@ -740,6 +742,7 @@ class ForwardTTSLoss(nn.Module):
        alignment_logprob=None,
        alignment_hard=None,
        alignment_soft=None,
        binary_loss_weight=None,
    ):
        loss = 0
        return_dict = {}

@@ -772,7 +775,12 @@ class ForwardTTSLoss(nn.Module):
        if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
            binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
            loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
            return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
            if binary_loss_weight:
                return_dict["loss_binary_alignment"] = (
                    self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
                )
            else:
                return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss

        return_dict["loss"] = loss
        return return_dict

@@ -141,7 +141,7 @@ class MultiHeadAttention(nn.Module):

        # score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)
        scores = scores / (self.key_dim**0.5)
        scores = F.softmax(scores, dim=3)

        # out = score * V
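The scaled dot-product attention from the comment above, as a stand-alone sketch (shapes are assumptions): dividing by sqrt(d_k) keeps the logits in a range where the softmax stays well-conditioned.

import torch
import torch.nn.functional as F

h, N, T_q, T_k, d_k = 4, 2, 6, 6, 32
queries = torch.randn(h, N, T_q, d_k)
keys = torch.randn(h, N, T_k, d_k)
scores = torch.matmul(queries, keys.transpose(2, 3)) / (d_k**0.5)  # [h, N, T_q, T_k]
scores = F.softmax(scores, dim=3)  # each query row now sums to 1 over T_k
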
@@ -6,7 +6,6 @@ from .attentions import init_attn
from .common_layers import Linear, Prenet


# NOTE: linter has a problem with the current TF release
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class ConvBNBlock(nn.Module):

@@ -57,7 +57,7 @@ class TextEncoder(nn.Module):

        self.emb = nn.Embedding(n_vocab, hidden_channels)

        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        if language_emb_dim:
            hidden_channels += language_emb_dim

@@ -83,6 +83,7 @@ class TextEncoder(nn.Module):
            - x: :math:`[B, T]`
            - x_length: :math:`[B]`
        """
        assert x.shape[0] == x_lengths.shape[0]
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]

        # concat the lang emb in embedding chars

@@ -90,7 +91,7 @@ class TextEncoder(nn.Module):
            x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

@@ -136,6 +137,9 @@ class ResidualCouplingBlock(nn.Module):

    def forward(self, x, x_mask, g=None, reverse=False):
        """
        Note:
            Set `reverse` to True for inference.

        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`

@@ -209,6 +213,9 @@ class ResidualCouplingBlocks(nn.Module):

    def forward(self, x, x_mask, g=None, reverse=False):
        """
        Note:
            Set `reverse` to True for inference.

        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`

@@ -33,7 +33,7 @@ class DilatedDepthSeparableConv(nn.Module):
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(num_layers):
            dilation = kernel_size ** i
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)

@@ -264,7 +264,7 @@ class StochasticDurationPredictor(nn.Module):
            # posterior encoder - neg log likelihood
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
            nll_posterior_encoder = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
                torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q
            )

            z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask

@@ -279,7 +279,7 @@ class StochasticDurationPredictor(nn.Module):
                z = torch.flip(z, [1])

            # flow layers - neg log likelihood
            nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
            nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
            return nll_flow_layers + nll_posterior_encoder

        flows = list(reversed(self.flows))
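The flow negative log-likelihood above is the standard-normal log-density summed over masked positions, minus the accumulated log-determinant. A minimal stand-alone check under assumed shapes:

import math
import torch

z = torch.randn(2, 1, 10)      # latent after the flows
x_mask = torch.ones(2, 1, 10)  # all positions valid
logdet_tot = torch.zeros(2)    # sum of per-flow log-determinants
nll = torch.sum(0.5 * (math.log(2 * math.pi) + z**2) * x_mask, [1, 2]) - logdet_tot
print(nll.shape)  # one NLL value per batch item
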
@@ -1,52 +1,14 @@
from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from typing import Dict, List, Union

from TTS.utils.generic_utils import find_module


def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
    print(" > Using model: {}".format(config.model))
    # fetch the right model implementation.
    if "base_model" in config and config["base_model"] is not None:
        MyModel = find_module("TTS.tts.models", config.base_model.lower())
    else:
        MyModel = find_module("TTS.tts.models", config.model.lower())
    # define the set of characters used by the model
    if config.characters is not None:
        # set characters from config
        if hasattr(MyModel, "make_symbols"):
            symbols = MyModel.make_symbols(config)
        else:
            symbols, phonemes = make_symbols(**config.characters)
    else:
        from TTS.tts.utils.text.symbols import phonemes, symbols  # pylint: disable=import-outside-toplevel

        if config.use_phonemes:
            symbols = phonemes
        # use default characters and assign them to config
        config.characters = parse_symbols()
    # consider the special `blank` character if `add_blank` is set True
    num_chars = len(symbols) + getattr(config, "add_blank", False)
    config.num_chars = num_chars
    # compatibility fix
    if "model_params" in config:
        config.model_params.num_chars = num_chars
    if "model_args" in config:
        config.model_args.num_chars = num_chars
    if config.model.lower() in ["vits"]:  # if the model supports multiple languages
        model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
    else:
        model = MyModel(config, speaker_manager=speaker_manager)
    model = MyModel.init_from_config(config, samples)
    return model


# TODO: class registry
# def import_models(models_dir, namespace):
#     for file in os.listdir(models_dir):
#         path = os.path.join(models_dir, file)
#         if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
#             model_name = file[: file.find(".py")] if file.endswith(".py") else file
#             importlib.import_module(namespace + "." + model_name)
#
#
## automatically import any Python files in the models/ directory
# models_dir = os.path.dirname(__file__)
# import_models(models_dir, "TTS.tts.models")
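A hedged usage sketch for the slimmed-down factory above (config values are illustrative): the tokenizer, audio processor, and speaker manager are now built inside `Model.init_from_config` rather than in `setup_model` itself.

from TTS.tts.configs.glow_tts_config import GlowTTSConfig

config = GlowTTSConfig()                   # assumed minimal config
model = setup_model(config, samples=None)  # returns a fully initialized BaseTTS subclass
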
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union

import torch
from coqpit import Coqpit

@@ -12,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec

@@ -100,11 +102,16 @@ class AlignTTS(BaseTTS):

    # pylint: disable=dangerous-default-value

    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
    def __init__(
        self,
        config: "AlignTTSConfig",
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
    ):

        super().__init__(config)
        super().__init__(config, ap, tokenizer, speaker_manager)
        self.speaker_manager = speaker_manager
        self.config = config
        self.phase = -1
        self.length_scale = (
            float(config.model_args.length_scale)

@@ -112,10 +119,6 @@ class AlignTTS(BaseTTS):
            else config.model_args.length_scale
        )

        if not self.config.model_args.num_chars:
            _, self.config, num_chars = self.get_characters(config)
            self.config.model_args.num_chars = num_chars

        self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)

        self.embedded_speaker_dim = 0

@@ -382,19 +385,17 @@ class AlignTTS(BaseTTS):
    def train_log(
        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ) -> None:  # pylint: disable=no-self-use
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)
        figures, audios = self._create_logs(batch, outputs, self.ap)
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, ap.sample_rate)
        logger.train_audios(steps, audios, self.ap.sample_rate)

    def eval_step(self, batch: dict, criterion: nn.Module):
        return self.train_step(batch, criterion)

    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)
        figures, audios = self._create_logs(batch, outputs, self.ap)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, ap.sample_rate)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False

@@ -430,3 +431,19 @@ class AlignTTS(BaseTTS):
    def on_epoch_start(self, trainer):
        """Set the AlignTTS training phase on epoch start."""
        self.phase = self._set_phase(trainer.config, trainer.total_steps_done)

    @staticmethod
    def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None):
        """Initialize the model from config.

        Args:
            config (AlignTTSConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        """
        from TTS.utils.audio import AudioProcessor

        ap = AudioProcessor.init_from_config(config)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        return AlignTTS(new_config, ap, tokenizer, speaker_manager)
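A sketch of the new two-step construction pattern (config values are illustrative): `init_from_config` assembles the AudioProcessor, tokenizer, and speaker manager internally, so callers no longer wire them up by hand.

from TTS.tts.configs.align_tts_config import AlignTTSConfig
from TTS.tts.models.align_tts import AlignTTS

config = AlignTTSConfig()  # assumed minimal config
model = AlignTTS.init_from_config(config, samples=None)
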
@@ -9,6 +9,8 @@ from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler

@@ -17,8 +19,14 @@ from TTS.utils.training import gradual_training_scheduler
class BaseTacotron(BaseTTS):
    """Base class shared by Tacotron and Tacotron2."""

    def __init__(self, config: Coqpit):
        super().__init__(config)
    def __init__(
        self,
        config: "TacotronConfig",
        ap: "AudioProcessor",
        tokenizer: "TTSTokenizer",
        speaker_manager: SpeakerManager = None,
    ):
        super().__init__(config, ap, tokenizer, speaker_manager)

        # pass all config fields as class attributes
        for key in config:

@@ -107,6 +115,16 @@ class BaseTacotron(BaseTTS):
        """Get the model criterion used in training."""
        return TacotronLoss(self.config)

    @staticmethod
    def init_from_config(config: Coqpit):
        """Initialize the model from config."""
        from TTS.utils.audio import AudioProcessor

        ap = AudioProcessor.init_from_config(config)
        # TTSTokenizer.init_from_config returns (tokenizer, updated_config), as elsewhere in this change set
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config)
        return BaseTacotron(new_config, ap, tokenizer, speaker_manager)

    #############################
    # COMMON COMPUTE FUNCTIONS
    #############################

@@ -1,6 +1,6 @@
import os
import random
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import torch
import torch.distributed as dist

@@ -9,33 +9,44 @@ from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

# pylint: skip-file


class BaseTTS(BaseModel):
class BaseTTS(BaseTrainerModel):
    """Base `tts` class. Every new `tts` model must inherit this.

    It defines common `tts` specific functions on top of the `Model` implementation.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor",
        tokenizer: "TTSTokenizer",
        speaker_manager: SpeakerManager = None,
        language_manager: LanguageManager = None,
    ):
        super().__init__()
        self.config = config
        self.ap = ap
        self.tokenizer = tokenizer
        self.speaker_manager = speaker_manager
        self.language_manager = language_manager
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit):
        """Setup model args based on the config type.
        """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).

        `ModelArgs` has all the fields required to initialize the model architecture.

        `ModelConfig` has all the fields required for training and inference, and contains `ModelArgs`.

        If the config is for training, with a name like "*Config", then the model args are embedded in
        `config.model_args`.

@@ -44,8 +55,11 @@ class BaseTTS(BaseModel):
        """
        # don't use isinstance here, to avoid importing recursively
        if "Config" in config.__class__.__name__:
            config_num_chars = (
                self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars
            )
            num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
            if "characters" in config:
                _, self.config, num_chars = self.get_characters(config)
                self.config.num_chars = num_chars
                if hasattr(self.config, "model_args"):
                    config.model_args.num_chars = num_chars

@@ -58,22 +72,6 @@ class BaseTTS(BaseModel):
        else:
            raise ValueError("config must be either a *Config or *Args")

    @staticmethod
    def get_characters(config: Coqpit) -> str:
        # TODO: implement CharacterProcessor
        if config.characters is not None:
            symbols, phonemes = make_symbols(**config.characters)
        else:
            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols

            config.characters = CharactersConfig(**parse_symbols())
        model_characters = phonemes if config.use_phonemes else symbols
        num_chars = len(model_characters) + getattr(config, "add_blank", False)
        return model_characters, config, num_chars

    def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
        return get_speaker_manager(config, restore_path, data, out_path)

    def init_multispeaker(self, config: Coqpit, data: List = None):
        """Initialize a speaker embedding layer if needed, and define the expected embedding channel size for setting
        `in_channels` size of the connected layers.
@@ -170,8 +168,8 @@ class BaseTTS(BaseModel):
            Dict: [description]
        """
        # setup input batch
        text_input = batch["text"]
        text_lengths = batch["text_lengths"]
        text_input = batch["token_id"]
        text_lengths = batch["token_id_lengths"]
        speaker_names = batch["speaker_names"]
        linear_input = batch["linear"]
        mel_input = batch["mel"]

@@ -239,7 +237,7 @@ class BaseTTS(BaseModel):
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        data_items: List,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,

@@ -247,8 +245,6 @@ class BaseTTS(BaseModel):
        if is_eval and not config.run_eval:
            loader = None
        else:
            ap = assets["audio_processor"]

            # setup multi-speaker attributes
            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                if hasattr(config, "model_args"):

@@ -264,12 +260,8 @@ class BaseTTS(BaseModel):
                speaker_id_mapping = None
                d_vector_mapping = None

            # setup custom symbols if needed
            custom_symbols = None
            if hasattr(self, "make_symbols"):
                custom_symbols = self.make_symbols(self.config)

            if hasattr(self, "language_manager"):
            # setup multi-lingual attributes
            if hasattr(self, "language_manager") and self.language_manager is not None:
                language_id_mapping = (
                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
                )

@@ -279,74 +271,40 @@ class BaseTTS(BaseModel):
            # init dataloader
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                text_cleaner=config.text_cleaner,
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
                meta_data=data_items,
                ap=ap,
                characters=config.characters,
                custom_symbols=custom_symbols,
                add_blank=config["add_blank"],
                samples=samples,
                ap=self.ap,
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_seq_len=config.min_seq_len,
                max_seq_len=config.max_seq_len,
                min_text_len=config.min_text_len,
                max_text_len=config.max_text_len,
                min_audio_len=config.min_audio_len,
                max_audio_len=config.max_audio_len,
                phoneme_cache_path=config.phoneme_cache_path,
                use_phonemes=config.use_phonemes,
                phoneme_language=config.phoneme_language,
                enable_eos_bos=config.enable_eos_bos_chars,
                precompute_num_workers=config.precompute_num_workers,
                use_noise_augment=False if is_eval else config.use_noise_augment,
                verbose=verbose,
                speaker_id_mapping=speaker_id_mapping,
                d_vector_mapping=d_vector_mapping,
                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
                tokenizer=self.tokenizer,
                start_by_longest=config.start_by_longest,
                language_id_mapping=language_id_mapping,
            )

            # pre-compute phonemes
            if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
                if hasattr(self, "eval_data_items") and is_eval:
                    dataset.items = self.eval_data_items
                elif hasattr(self, "train_data_items") and not is_eval:
                    dataset.items = self.train_data_items
                else:
                    # precompute phonemes for a precise estimate of sequence lengths.
                    # otherwise `dataset.sort_items()` uses raw text lengths
                    dataset.compute_input_seq(config.num_loader_workers)

                # TODO: find a more efficient solution
                # cheap hack - store items in the model state to avoid recomputing when the dataset is reinitialized
                if is_eval:
                    self.eval_data_items = dataset.items
                else:
                    self.train_data_items = dataset.items

            # halt DDP processes for the main process to finish computing the phoneme cache
            # wait for all DDP processes to be ready
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))

            # compute pitch frames and write them to files.
            if config.compute_f0 and rank in [None, 0]:
                if not os.path.exists(config.f0_cache_path):
                    dataset.pitch_extractor.compute_pitch(
                        ap, config.get("f0_cache_path", None), config.num_loader_workers
                    )

            # halt DDP processes for the main process to finish computing the F0 cache
            if num_gpus > 1:
                dist.barrier()

            # load the pitch stats computed above by all the workers
            if config.compute_f0:
                dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
            dataset.preprocess_samples()

            # sampler for DDP
            sampler = DistributedSampler(dataset) if num_gpus > 1 else None

            # Weighted samplers
            # TODO: make this DDP amenable
            assert not (
                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
            ), "language_weighted_sampler is not supported with DistributedSampler"

@@ -357,17 +315,17 @@ class BaseTTS(BaseModel):
            if sampler is None:
                if getattr(config, "use_language_weighted_sampler", False):
                    print(" > Using Language weighted sampler")
                    sampler = get_language_weighted_sampler(dataset.items)
                    sampler = get_language_weighted_sampler(dataset.samples)
                elif getattr(config, "use_speaker_weighted_sampler", False):
                    print(" > Using Speaker weighted sampler")
                    sampler = get_speaker_weighted_sampler(dataset.items)
                    sampler = get_speaker_weighted_sampler(dataset.samples)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=False,
                shuffle=False,  # shuffling is done in the dataset.
                collate_fn=dataset.collate_fn,
                drop_last=False,
                drop_last=False,  # setting this False might cause issues in AMP training.
                sampler=sampler,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,

@@ -403,7 +361,6 @@ class BaseTTS(BaseModel):
        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        ap = assets["audio_processor"]
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}

@@ -415,17 +372,15 @@ class BaseTTS(BaseModel):
                sen,
                self.config,
                "cuda" in str(next(self.parameters()).device),
                ap,
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
            )
            test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
                outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
                outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
            )
            test_figures["{}-alignment".format(idx)] = plot_alignment(
                outputs_dict["outputs"]["alignments"], output_fig=False
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, Tuple
from typing import Dict, List, Tuple, Union

import torch
from coqpit import Coqpit

@@ -14,7 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram


@dataclass

@@ -170,17 +171,22 @@ class ForwardTTS(BaseTTS):
    """

    # pylint: disable=dangerous-default-value
    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
    ):
        super().__init__(config, ap, tokenizer, speaker_manager)
        self._set_model_args(config)

        super().__init__(config)

        self.speaker_manager = speaker_manager
        self.init_multispeaker(config)

        self.max_duration = self.args.max_duration
        self.use_aligner = self.args.use_aligner
        self.use_pitch = self.args.use_pitch
        self.use_binary_alignment_loss = False
        self.binary_loss_weight = 0.0

        self.length_scale = (
            float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale

@@ -255,7 +261,7 @@ class ForwardTTS(BaseTTS):
        # init speaker embedding layer
        if config.use_speaker_embedding and not config.use_d_vector_file:
            print(" > Init speaker_embedding layer.")
            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels)
            self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

    @staticmethod

@@ -638,8 +644,9 @@ class ForwardTTS(BaseTTS):
                pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
                input_lens=text_lengths,
                alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
                alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
                alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
                alignment_soft=outputs["alignment_soft"],
                alignment_hard=outputs["alignment_mas"],
                binary_loss_weight=self.binary_loss_weight,
            )
            # compute duration error
            durations_pred = outputs["durations"]

@@ -666,17 +673,12 @@ class ForwardTTS(BaseTTS):

        # plot pitch figures
        if self.args.use_pitch:
            pitch = batch["pitch"]
            pitch_avg_expanded, _ = self.expand_encoder_outputs(
                outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
            )
            pitch = pitch[0, 0].data.cpu().numpy()
            # TODO: denormalize before plotting
            pitch = abs(pitch)
            pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
            pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy())
            pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy())
            chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
            pitch_figures = {
                "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
                "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
                "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
                "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
            }
            figures.update(pitch_figures)

@@ -692,19 +694,17 @@ class ForwardTTS(BaseTTS):
    def train_log(
        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ) -> None:  # pylint: disable=no-self-use
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)
        figures, audios = self._create_logs(batch, outputs, self.ap)
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, ap.sample_rate)
        logger.train_audios(steps, audios, self.ap.sample_rate)

    def eval_step(self, batch: dict, criterion: nn.Module):
        return self.train_step(batch, criterion)

    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)
        figures, audios = self._create_logs(batch, outputs, self.ap)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, ap.sample_rate)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
@ -721,6 +721,21 @@ class ForwardTTS(BaseTTS):
|
|||
return ForwardTTSLoss(self.config)
|
||||
|
||||
def on_train_step_start(self, trainer):
|
||||
"""Enable binary alignment loss when needed"""
|
||||
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
|
||||
self.use_binary_alignment_loss = True
|
||||
"""Schedule binary loss weight."""
|
||||
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
||||
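# A worked example of the schedule above (illustrative, assuming
# binary_loss_warmup_epochs=10 in the config): epochs_done=0 -> weight 0.0,
# epochs_done=5 -> 0.5, epochs_done>=10 -> capped at 1.0. The trailing
# "* 1.0" only casts the result to float; it does not rescale the weight.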
|
||||
@staticmethod
|
||||
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (ForwardTTSConfig): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||
return ForwardTTS(new_config, ap, tokenizer, speaker_manager)
|
||||
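# Minimal usage sketch (FastPitchConfig is one ForwardTTS-family config in
# this repo; default values are used here for illustration only):
#   from TTS.tts.configs.fast_pitch_config import FastPitchConfig
#   model = ForwardTTS.init_from_config(FastPitchConfig())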
|
|
|
@ -1,5 +1,5 @@
|
|||
import math
|
||||
from typing import Dict, Tuple, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
@ -14,6 +14,7 @@ from TTS.tts.models.base_tts import BaseTTS
|
|||
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
@ -39,18 +40,31 @@ class GlowTTS(BaseTTS):
|
|||
Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
|
||||
|
||||
Examples:
|
||||
Init only model layers.
|
||||
|
||||
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||
>>> from TTS.tts.models.glow_tts import GlowTTS
|
||||
>>> config = GlowTTSConfig(num_chars=2)
|
||||
>>> model = GlowTTS(config)
|
||||
|
||||
Fully init a model ready for action. All the class attributes and class members
|
||||
(e.g. Tokenizer, AudioProcessor, etc.) are initialized internally based on config values.
|
||||
|
||||
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||
>>> from TTS.tts.models.glow_tts import GlowTTS
|
||||
>>> config = GlowTTSConfig()
|
||||
>>> model = GlowTTS(config)
|
||||
|
||||
>>> model = GlowTTS.init_from_config(config, verbose=False)
|
||||
"""
|
||||
|
||||
def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: GlowTTSConfig,
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
# pass all config fields to `self`
|
||||
# for fewer code change
|
||||
|
@ -58,7 +72,6 @@ class GlowTTS(BaseTTS):
|
|||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
_, self.config, self.num_chars = self.get_characters(config)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
# init multi-speaker layers if necessary
|
||||
|
@ -94,25 +107,25 @@ class GlowTTS(BaseTTS):
|
|||
|
||||
def init_multispeaker(self, config: Coqpit):
|
||||
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
|
||||
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
|
||||
vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets
|
||||
speaker embedding vector dimension to the d-vector dimension from the config.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
"""
|
||||
self.embedded_speaker_dim = 0
|
||||
# init speaker manager
|
||||
if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
|
||||
raise ValueError(
|
||||
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
|
||||
)
|
||||
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
|
||||
if self.speaker_manager is not None:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# set ultimate speaker embedding size
|
||||
if config.use_speaker_embedding or config.use_d_vector_file:
|
||||
if config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||
)
|
||||
if self.speaker_manager is not None:
|
||||
assert (
|
||||
config.d_vector_dim == self.speaker_manager.d_vector_dim
|
||||
), " [!] d-vector dimension mismatch b/w config and speaker manager."
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
print(" > Init speaker_embedding layer.")
|
||||
|
@ -170,6 +183,8 @@ class GlowTTS(BaseTTS):
|
|||
if g is not None:
|
||||
if hasattr(self, "emb_g"):
|
||||
# use speaker embedding layer
|
||||
if not g.size(): # if is a scalar
|
||||
g = g.unsqueeze(0) # unsqueeze
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
else:
|
||||
# use d-vector
|
||||
|
@ -180,12 +195,33 @@ class GlowTTS(BaseTTS):
|
|||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, T]`
|
||||
- x_lengths: :math:`B`
|
||||
- y: :math:`[B, T, C]`
|
||||
- y_lengths: :math:`B`
|
||||
- g: :math:`[B, C] or B`
|
||||
Args:
|
||||
x (torch.Tensor):
|
||||
Input text sequence ids. :math:`[B, T_en]`
|
||||
|
||||
x_lengths (torch.Tensor):
|
||||
Lengths of input text sequences. :math:`[B]`
|
||||
|
||||
y (torch.Tensor):
|
||||
Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
|
||||
|
||||
y_lengths (torch.Tensor):
|
||||
Lengths of target mel-spectrogram frames. :math:`[B]`
|
||||
|
||||
aux_input (Dict):
|
||||
Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model.
|
||||
:math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model using a speaker-embedding
|
||||
layer. :math:`B`
|
||||
|
||||
Returns:
|
||||
Dict:
|
||||
- z: :math: `[B, T_de, C]`
|
||||
- logdet: :math:`B`
|
||||
- y_mean: :math:`[B, T_de, C]`
|
||||
- y_log_scale: :math:`[B, T_de, C]`
|
||||
- alignments: :math:`[B, T_en, T_de]`
|
||||
- durations_log: :math:`[B, T_en, 1]`
|
||||
- total_durations_log: :math:`[B, T_en, 1]`
|
||||
"""
|
||||
# [B, T, C] -> [B, C, T]
|
||||
y = y.transpose(1, 2)
|
||||
|
@ -206,9 +242,9 @@ class GlowTTS(BaseTTS):
|
|||
with torch.no_grad():
|
||||
o_scale = torch.exp(-2 * o_log_scale)
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||
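# Note: logp1..logp4 above expand the diagonal Gaussian log-likelihood
# log N(z; o_mean, exp(o_log_scale)^2) for every (text-token, mel-frame)
# pair without materializing a [b, t, t', d] tensor. With
# o_scale = 1 / sigma^2, the quadratic -0.5 * (z - mean)^2 * o_scale splits
# into logp2 + logp3 + logp4, while logp1 carries the constant and
# log-sigma terms; maximum_path then runs monotonic alignment search over logp.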
|
@ -255,9 +291,9 @@ class GlowTTS(BaseTTS):
|
|||
# find the alignment path between z and encoder output
|
||||
o_scale = torch.exp(-2 * o_log_scale)
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
|
||||
|
@ -422,20 +458,18 @@ class GlowTTS(BaseTTS):
|
|||
def train_log(
|
||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, ap.sample_rate)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
|
||||
|
@ -446,7 +480,6 @@ class GlowTTS(BaseTTS):
|
|||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||
"""
|
||||
ap = assets["audio_processor"]
|
||||
print(" | > Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
|
@ -461,18 +494,16 @@ class GlowTTS(BaseTTS):
|
|||
sen,
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
ap,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
)
|
||||
|
||||
test_audios["{}-audio".format(idx)] = outputs["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs["outputs"]["model_outputs"], ap, output_fig=False
|
||||
outputs["outputs"]["model_outputs"], self.ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
|
||||
return test_figures, test_audios
|
||||
|
@ -499,7 +530,8 @@ class GlowTTS(BaseTTS):
|
|||
self.store_inverse()
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self):
|
||||
@staticmethod
|
||||
def get_criterion():
|
||||
from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel
|
||||
|
||||
return GlowTTSLoss()
|
||||
|
@ -507,3 +539,20 @@ class GlowTTS(BaseTTS):
|
|||
def on_train_step_start(self, trainer):
|
||||
"""Decide on every training step wheter enable/disable data depended initialization."""
|
||||
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (GlowTTSConfig): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
verbose (bool): If True, print init messages. Defaults to True.
|
||||
"""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config, verbose)
|
||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||
return GlowTTS(new_config, ap, tokenizer, speaker_manager)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
||||
|
@ -10,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
|
|||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
||||
|
||||
|
@ -24,12 +26,15 @@ class Tacotron(BaseTacotron):
|
|||
a multi-speaker model. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
super().__init__(config)
|
||||
def __init__(
|
||||
self,
|
||||
config: "TacotronConfig",
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
chars, self.config, _ = self.get_characters(config)
|
||||
config.num_chars = self.num_chars = len(chars)
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
# pass all config fields to `self`
|
||||
# for fewer code change
|
||||
|
@ -302,16 +307,30 @@ class Tacotron(BaseTacotron):
|
|||
def train_log(
|
||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, ap.sample_rate)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (TacotronConfig): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||
return Tacotron(new_config, ap, tokenizer, speaker_manager)
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
||||
|
@ -12,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
|||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
||||
|
||||
|
@ -40,12 +40,16 @@ class Tacotron2(BaseTacotron):
|
|||
Speaker manager for multi-speaker training. Use it only for multi-speaker training. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
super().__init__(config)
|
||||
def __init__(
|
||||
self,
|
||||
config: "Tacotron2Config",
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
chars, self.config, _ = self.get_characters(config)
|
||||
config.num_chars = len(chars)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
# pass all config fields to `self`
|
||||
|
@ -325,16 +329,30 @@ class Tacotron2(BaseTacotron):
|
|||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
"""Log training progress."""
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, ap.sample_rate)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (Tacotron2Config): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(new_config, samples)
|
||||
return Tacotron2(new_config, ap, tokenizer, speaker_manager)
|
||||
|
|
File diff suppressed because it is too large
|
@ -1,20 +0,0 @@
|
|||
## Utilities to Convert Models to Tensorflow2
|
||||
Here are experimental utilities to convert trained Torch models to Tensorflow (>=2.2).
|
||||
|
||||
Converting Torch models to TF makes the whole TF toolkit available for deployment and device-specific optimizations.
|
||||
|
||||
Note that we do not plan to share training scripts for Tensorflow in the near future, but any contribution in that direction would be more than welcome.
|
||||
|
||||
To see how you can use a TF model at inference, check the notebook.
|
||||
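A minimal inference sketch is shown below (the checkpoint path, the model hyper-parameters, and the `load_checkpoint` import location are assumptions for illustration, not a documented API):

```python
import tensorflow as tf

from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.io import load_checkpoint  # import path assumed

# Build the model skeleton with the hyper-parameters the Torch model was
# trained with (placeholder values below), then load the converted weights.
model = Tacotron2(num_chars=24, num_speakers=0, r=2)
model.build_inference()  # run once on dummy input so variables are created
model = load_checkpoint(model, "tf_model.pkl")

# Synthesize from dummy character ids; real usage feeds tokenized text.
characters = tf.random.uniform([1, 32], maxval=24, dtype=tf.int32)
decoder_frames, postnet_frames, alignments, stop_tokens = model.inference(characters)
```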
|
||||
This is an experimental release. If you encounter an error, please open an issue or, better yet, send a PR, but you are mostly on your own.
|
||||
|
||||
|
||||
### Converting a Model
|
||||
- Run ```convert_tacotron2_torch_to_tf.py --torch_model_path /path/to/torch/model.pth.tar --config_path /path/to/model/config.json --output_path /path/to/output/tf/model``` with the right arguments.
|
||||
|
||||
### Known issues and limitations
|
||||
- We use a custom model load/save mechanism which enables us to store model-related information with the model weights (similar to Torch). However, it is prone to random errors.
|
||||
- The current TF model implementation is slightly slower than the Torch model. Hopefully, it will get better with improving TF support for eager mode and ```tf.function```.
|
||||
- The TF implementation of Tacotron2 only supports regular Tacotron2 as described in the paper.
|
||||
- You can only convert models trained after the TF model implementation was added, since the model layers have been updated in the Torch model.
|
|
@ -1,301 +0,0 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
# from tensorflow_addons.seq2seq import BahdanauAttention
|
||||
|
||||
# NOTE: linter has a problem with the current TF release
|
||||
# pylint: disable=no-value-for-parameter
|
||||
# pylint: disable=unexpected-keyword-arg
|
||||
|
||||
|
||||
class Linear(keras.layers.Layer):
|
||||
def __init__(self, units, use_bias, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
|
||||
self.activation = keras.layers.ReLU()
|
||||
|
||||
def call(self, x):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
return self.activation(self.linear_layer(x))
|
||||
|
||||
|
||||
class LinearBN(keras.layers.Layer):
|
||||
def __init__(self, units, use_bias, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
|
||||
self.batch_normalization = keras.layers.BatchNormalization(
|
||||
axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization"
|
||||
)
|
||||
self.activation = keras.layers.ReLU()
|
||||
|
||||
def call(self, x, training=None):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
out = self.linear_layer(x)
|
||||
out = self.batch_normalization(out, training=training)
|
||||
return self.activation(out)
|
||||
|
||||
|
||||
class Prenet(keras.layers.Layer):
|
||||
def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.prenet_type = prenet_type
|
||||
self.prenet_dropout = prenet_dropout
|
||||
self.linear_layers = []
|
||||
if prenet_type == "bn":
|
||||
self.linear_layers += [
|
||||
LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
|
||||
]
|
||||
elif prenet_type == "original":
|
||||
self.linear_layers += [
|
||||
Linear(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
|
||||
]
|
||||
else:
|
||||
raise RuntimeError(" [!] Unknown prenet type.")
|
||||
if prenet_dropout:
|
||||
self.dropout = keras.layers.Dropout(rate=0.5)
|
||||
|
||||
def call(self, x, training=None):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
for linear in self.linear_layers:
|
||||
if self.prenet_dropout:
|
||||
x = self.dropout(linear(x), training=training)
|
||||
else:
|
||||
x = linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def _sigmoid_norm(score):
|
||||
attn_weights = tf.nn.sigmoid(score)
|
||||
attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
|
||||
return attn_weights
|
||||
|
||||
|
||||
class Attention(keras.layers.Layer):
|
||||
"""TODO: implement forward_attention
|
||||
TODO: location sensitive attention
|
||||
TODO: implement attention windowing"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attn_dim,
|
||||
use_loc_attn,
|
||||
loc_attn_n_filters,
|
||||
loc_attn_kernel_size,
|
||||
use_windowing,
|
||||
norm,
|
||||
use_forward_attn,
|
||||
use_trans_agent,
|
||||
use_forward_attn_mask,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.use_loc_attn = use_loc_attn
|
||||
self.loc_attn_n_filters = loc_attn_n_filters
|
||||
self.loc_attn_kernel_size = loc_attn_kernel_size
|
||||
self.use_windowing = use_windowing
|
||||
self.norm = norm
|
||||
self.use_forward_attn = use_forward_attn
|
||||
self.use_trans_agent = use_trans_agent
|
||||
self.use_forward_attn_mask = use_forward_attn_mask
|
||||
self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer")
|
||||
self.inputs_layer = tf.keras.layers.Dense(
|
||||
attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer"
|
||||
)
|
||||
self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer")
|
||||
if use_loc_attn:
|
||||
self.location_conv1d = keras.layers.Conv1D(
|
||||
filters=loc_attn_n_filters,
|
||||
kernel_size=loc_attn_kernel_size,
|
||||
padding="same",
|
||||
use_bias=False,
|
||||
name="location_layer/location_conv1d",
|
||||
)
|
||||
self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense")
|
||||
if norm == "softmax":
|
||||
self.norm_func = tf.nn.softmax
|
||||
elif norm == "sigmoid":
|
||||
self.norm_func = _sigmoid_norm
|
||||
else:
|
||||
raise ValueError("Unknown value for attention norm type")
|
||||
|
||||
def init_states(self, batch_size, value_length):
|
||||
states = []
|
||||
if self.use_loc_attn:
|
||||
attention_cum = tf.zeros([batch_size, value_length])
|
||||
attention_old = tf.zeros([batch_size, value_length])
|
||||
states = [attention_cum, attention_old]
|
||||
if self.use_forward_attn:
|
||||
alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1)
|
||||
states.append(alpha)
|
||||
return tuple(states)
|
||||
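# States layout used throughout this layer, assuming all options enabled:
# (attention_cum, attention_old, alpha) -- cumulative attention weights and
# previous-step weights for location attention, plus the forward-attention
# distribution alpha.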
|
||||
def process_values(self, values):
|
||||
"""cache values for decoder iterations"""
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.processed_values = self.inputs_layer(values)
|
||||
self.values = values
|
||||
|
||||
def get_loc_attn(self, query, states):
|
||||
"""compute location attention, query layer and
|
||||
unnorm. attention weights"""
|
||||
attention_cum, attention_old = states[:2]
|
||||
attn_cat = tf.stack([attention_old, attention_cum], axis=2)
|
||||
|
||||
processed_query = self.query_layer(tf.expand_dims(query, 1))
|
||||
processed_attn = self.location_dense(self.location_conv1d(attn_cat))
|
||||
score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn))
|
||||
score = tf.squeeze(score, axis=2)
|
||||
return score, processed_query
|
||||
|
||||
def get_attn(self, query):
|
||||
"""compute query layer and unnormalized attention weights"""
|
||||
processed_query = self.query_layer(tf.expand_dims(query, 1))
|
||||
score = self.v(tf.nn.tanh(self.processed_values + processed_query))
|
||||
score = tf.squeeze(score, axis=2)
|
||||
return score, processed_query
|
||||
|
||||
def apply_score_masking(self, score, mask): # pylint: disable=no-self-use
|
||||
"""ignore sequence paddings"""
|
||||
padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
|
||||
# Bias so padding positions do not contribute to attention distribution.
|
||||
score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32)
|
||||
return score
|
||||
|
||||
def apply_forward_attention(self, alignment, alpha): # pylint: disable=no-self-use
|
||||
# forward attention
|
||||
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
|
||||
# compute transition potentials
|
||||
new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
|
||||
# renormalize attention weights
|
||||
new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
|
||||
return new_alpha
|
||||
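# Note: this is the forward-attention recurrence (cf. Zhang et al., 2018)
# with a fixed 0.5/0.5 mix of "stay" (alpha) and "advance one encoder step"
# (fwd_shifted_alpha, i.e. alpha shifted right by one); the 1e-8 term keeps
# the renormalization above from dividing by zero.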
|
||||
def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
|
||||
states = []
|
||||
if self.use_loc_attn:
|
||||
states = [old_states[0] + scores_norm, attn_weights]
|
||||
if self.use_forward_attn:
|
||||
states.append(new_alpha)
|
||||
return tuple(states)
|
||||
|
||||
def call(self, query, states):
|
||||
"""
|
||||
shapes:
|
||||
query: B x D
|
||||
"""
|
||||
if self.use_loc_attn:
|
||||
score, _ = self.get_loc_attn(query, states)
|
||||
else:
|
||||
score, _ = self.get_attn(query)
|
||||
|
||||
# TODO: masking
|
||||
# if mask is not None:
|
||||
# self.apply_score_masking(score, mask)
|
||||
# attn_weights shape == (batch_size, max_length, 1)
|
||||
|
||||
# normalize attention scores
|
||||
scores_norm = self.norm_func(score)
|
||||
attn_weights = scores_norm
|
||||
|
||||
# apply forward attention
|
||||
new_alpha = None
|
||||
if self.use_forward_attn:
|
||||
new_alpha = self.apply_forward_attention(attn_weights, states[-1])
|
||||
attn_weights = new_alpha
|
||||
|
||||
# update states tuple
|
||||
# states = (cum_attn_weights, attn_weights, new_alpha)
|
||||
states = self.update_states(states, scores_norm, attn_weights, new_alpha)
|
||||
|
||||
# context_vector shape after sum == (batch_size, hidden_size)
|
||||
context_vector = tf.matmul(
|
||||
tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False
|
||||
)
|
||||
context_vector = tf.squeeze(context_vector, axis=1)
|
||||
return context_vector, attn_weights, states
|
||||
|
||||
|
||||
# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
|
||||
# dtype = processed_query.dtype
|
||||
# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
|
||||
# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])
|
||||
|
||||
|
||||
# class LocationSensitiveAttention(BahdanauAttention):
|
||||
# def __init__(self,
|
||||
# units,
|
||||
# memory=None,
|
||||
# memory_sequence_length=None,
|
||||
# normalize=False,
|
||||
# probability_fn="softmax",
|
||||
# kernel_initializer="glorot_uniform",
|
||||
# dtype=None,
|
||||
# name="LocationSensitiveAttention",
|
||||
# location_attention_filters=32,
|
||||
# location_attention_kernel_size=31):
|
||||
|
||||
# super( self).__init__(units=units,
|
||||
# memory=memory,
|
||||
# memory_sequence_length=memory_sequence_length,
|
||||
# normalize=normalize,
|
||||
# probability_fn='softmax', ## parent module default
|
||||
# kernel_initializer=kernel_initializer,
|
||||
# dtype=dtype,
|
||||
# name=name)
|
||||
# if probability_fn == 'sigmoid':
|
||||
# self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
|
||||
# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
|
||||
# self.location_dense = keras.layers.Dense(units, use_bias=False)
|
||||
# # self.v = keras.layers.Dense(1, use_bias=True)
|
||||
|
||||
# def _location_sensitive_score(self, processed_query, keys, processed_loc):
|
||||
# processed_query = tf.expand_dims(processed_query, 1)
|
||||
# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])
|
||||
|
||||
# def _location_sensitive(self, alignment_cum, alignment_old):
|
||||
# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
|
||||
# return self.location_dense(self.location_conv(alignment_cat))
|
||||
|
||||
# def _sigmoid_normalization(self, score):
|
||||
# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)
|
||||
|
||||
# # def _apply_masking(self, score, mask):
|
||||
# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
|
||||
# # # Bias so padding positions do not contribute to attention distribution.
|
||||
# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
|
||||
# # return score
|
||||
|
||||
# def _calculate_attention(self, query, state):
|
||||
# alignment_cum, alignment_old = state[:2]
|
||||
# processed_query = self.query_layer(
|
||||
# query) if self.query_layer else query
|
||||
# processed_loc = self._location_sensitive(alignment_cum, alignment_old)
|
||||
# score = self._location_sensitive_score(
|
||||
# processed_query,
|
||||
# self.keys,
|
||||
# processed_loc)
|
||||
# alignment = self.probability_fn(score, state)
|
||||
# alignment_cum = alignment_cum + alignment
|
||||
# state[0] = alignment_cum
|
||||
# state[1] = alignment
|
||||
# return alignment, state
|
||||
|
||||
# def compute_context(self, alignments):
|
||||
# expanded_alignments = tf.expand_dims(alignments, 1)
|
||||
# context = tf.matmul(expanded_alignments, self.values)
|
||||
# context = tf.squeeze(context, [1])
|
||||
# return context
|
||||
|
||||
# # def call(self, query, state):
|
||||
# # alignment, next_state = self._calculate_attention(query, state)
|
||||
# # return alignment, next_state
|
|
@ -1,322 +0,0 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
|
||||
from TTS.tts.tf.layers.tacotron.common_layers import Attention, Prenet
|
||||
from TTS.tts.tf.utils.tf_utils import shape_list
|
||||
|
||||
|
||||
# NOTE: linter has a problem with the current TF release
|
||||
# pylint: disable=no-value-for-parameter
|
||||
# pylint: disable=unexpected-keyword-arg
|
||||
class ConvBNBlock(keras.layers.Layer):
|
||||
def __init__(self, filters, kernel_size, activation, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d")
|
||||
self.batch_normalization = keras.layers.BatchNormalization(
|
||||
axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization"
|
||||
)
|
||||
self.dropout = keras.layers.Dropout(rate=0.5, name="dropout")
|
||||
self.activation = keras.layers.Activation(activation, name="activation")
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = self.convolution1d(x)
|
||||
o = self.batch_normalization(o, training=training)
|
||||
o = self.activation(o)
|
||||
o = self.dropout(o, training=training)
|
||||
return o
|
||||
|
||||
|
||||
class Postnet(keras.layers.Layer):
|
||||
def __init__(self, output_filters, num_convs, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.convolutions = []
|
||||
self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0"))
|
||||
for idx in range(1, num_convs - 1):
|
||||
self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}"))
|
||||
self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}"))
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o, training=training)
|
||||
return o
|
||||
|
||||
|
||||
class Encoder(keras.layers.Layer):
|
||||
def __init__(self, output_input_dim, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.convolutions = []
|
||||
for idx in range(3):
|
||||
self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}"))
|
||||
self.lstm = keras.layers.Bidirectional(
|
||||
keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm"
|
||||
)
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o, training=training)
|
||||
o = self.lstm(o)
|
||||
return o
|
||||
|
||||
|
||||
class Decoder(keras.layers.Layer):
|
||||
# pylint: disable=unused-argument
|
||||
def __init__(
|
||||
self,
|
||||
frame_dim,
|
||||
r,
|
||||
attn_type,
|
||||
use_attn_win,
|
||||
attn_norm,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
use_forward_attn,
|
||||
use_trans_agent,
|
||||
use_forward_attn_mask,
|
||||
use_location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
speaker_emb_dim,
|
||||
enable_tflite,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.frame_dim = frame_dim
|
||||
self.r_init = tf.constant(r, dtype=tf.int32)
|
||||
self.r = tf.constant(r, dtype=tf.int32)
|
||||
self.output_dim = r * self.frame_dim
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.enable_tflite = enable_tflite
|
||||
|
||||
# layer constants
|
||||
self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
|
||||
self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
|
||||
|
||||
# model dimensions
|
||||
self.query_dim = 1024
|
||||
self.decoder_rnn_dim = 1024
|
||||
self.prenet_dim = 256
|
||||
self.attn_dim = 128
|
||||
self.p_attention_dropout = 0.1
|
||||
self.p_decoder_dropout = 0.1
|
||||
|
||||
self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet")
|
||||
self.attention_rnn = keras.layers.LSTMCell(
|
||||
self.query_dim,
|
||||
use_bias=True,
|
||||
name="attention_rnn",
|
||||
)
|
||||
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
|
||||
|
||||
# TODO: implement other attn options
|
||||
self.attention = Attention(
|
||||
attn_dim=self.attn_dim,
|
||||
use_loc_attn=True,
|
||||
loc_attn_n_filters=32,
|
||||
loc_attn_kernel_size=31,
|
||||
use_windowing=False,
|
||||
norm=attn_norm,
|
||||
use_forward_attn=use_forward_attn,
|
||||
use_trans_agent=use_trans_agent,
|
||||
use_forward_attn_mask=use_forward_attn_mask,
|
||||
name="attention",
|
||||
)
|
||||
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn")
|
||||
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
|
||||
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer")
|
||||
self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer")
|
||||
|
||||
def set_max_decoder_steps(self, new_max_steps):
|
||||
self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)
|
||||
|
||||
def set_r(self, new_r):
|
||||
self.r = tf.constant(new_r, dtype=tf.int32)
|
||||
self.output_dim = self.frame_dim * new_r
|
||||
|
||||
def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
|
||||
zero_frame = tf.zeros([batch_size, self.frame_dim])
|
||||
zero_context = tf.zeros([batch_size, memory_dim])
|
||||
attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
|
||||
decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
|
||||
attention_states = self.attention.init_states(batch_size, memory_length)
|
||||
return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states
|
||||
|
||||
def step(self, prenet_next, states, memory_seq_length=None, training=None):
|
||||
_, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
|
||||
attention_rnn_input = tf.concat([prenet_next, context_next], -1)
|
||||
attention_rnn_output, attention_rnn_state = self.attention_rnn(
|
||||
attention_rnn_input, attention_rnn_state, training=training
|
||||
)
|
||||
attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
|
||||
context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
|
||||
decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
|
||||
decoder_rnn_output, decoder_rnn_state = self.decoder_rnn(
|
||||
decoder_rnn_input, decoder_rnn_state, training=training
|
||||
)
|
||||
decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
|
||||
linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
|
||||
output_frame = self.linear_projection(linear_projection_input, training=training)
|
||||
stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
|
||||
stopnet_output = self.stopnet(stopnet_input, training=training)
|
||||
output_frame = output_frame[:, : self.r * self.frame_dim]
|
||||
states = (
|
||||
output_frame[:, self.frame_dim * (self.r - 1) :],
|
||||
context,
|
||||
attention_rnn_state,
|
||||
decoder_rnn_state,
|
||||
attention_states,
|
||||
)
|
||||
return output_frame, stopnet_output, states, attention
|
||||
|
||||
def decode(self, memory, states, frames, memory_seq_length=None):
|
||||
B, _, _ = shape_list(memory)
|
||||
num_iter = shape_list(frames)[1] // self.r
|
||||
# init states
|
||||
frame_zero = tf.expand_dims(states[0], 1)
|
||||
frames = tf.concat([frame_zero, frames], axis=1)
|
||||
outputs = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
attentions = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
prenet_output = self.prenet(frames, training=True)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
|
||||
prenet_next = prenet_output[:, step]
|
||||
output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length)
|
||||
outputs = outputs.write(step, output)
|
||||
attentions = attentions.write(step, attention)
|
||||
stop_tokens = stop_tokens.write(step, stop_token)
|
||||
return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
|
||||
|
||||
_, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop(
|
||||
lambda *arg: True,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=num_iter,
|
||||
)
|
||||
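# Note: the while_loop condition above (`lambda *arg: True`) never stops the
# loop itself; maximum_iterations=num_iter bounds it to exactly one decoder
# step per r ground-truth frames, i.e. teacher-forced decoding.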
|
||||
outputs = outputs.stack()
|
||||
attentions = attentions.stack()
|
||||
stop_tokens = stop_tokens.stack()
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
attentions = tf.transpose(attentions, [1, 0, 2])
|
||||
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
|
||||
stop_tokens = tf.squeeze(stop_tokens, axis=2)
|
||||
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
def decode_inference(self, memory, states):
|
||||
B, _, _ = shape_list(memory)
|
||||
# init states
|
||||
outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
|
||||
# iter vars
|
||||
stop_flag = tf.constant(False, dtype=tf.bool)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
|
||||
frame_next = states[0]
|
||||
prenet_next = self.prenet(frame_next, training=False)
|
||||
output, stop_token, states, attention = self.step(prenet_next, states, None, training=False)
|
||||
stop_token = tf.math.sigmoid(stop_token)
|
||||
outputs = outputs.write(step, output)
|
||||
attentions = attentions.write(step, attention)
|
||||
stop_tokens = stop_tokens.write(step, stop_token)
|
||||
stop_flag = tf.greater(stop_token, self.stop_thresh)
|
||||
stop_flag = tf.reduce_all(stop_flag)
|
||||
return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag
|
||||
|
||||
cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
|
||||
_, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop(
|
||||
cond,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=self.max_decoder_steps,
|
||||
)
|
||||
|
||||
outputs = outputs.stack()
|
||||
attentions = attentions.stack()
|
||||
stop_tokens = stop_tokens.stack()
|
||||
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
attentions = tf.transpose(attentions, [1, 0, 2])
|
||||
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
|
||||
stop_tokens = tf.squeeze(stop_tokens, axis=2)
|
||||
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
def decode_inference_tflite(self, memory, states):
|
||||
"""Inference with TF-Lite compatibility. It assumes
|
||||
batch_size is 1"""
|
||||
# init states
|
||||
# dynamic_shape is not supported in TFLite
|
||||
outputs = tf.TensorArray(
|
||||
dtype=tf.float32,
|
||||
size=self.max_decoder_steps,
|
||||
element_shape=tf.TensorShape([self.output_dim]),
|
||||
clear_after_read=False,
|
||||
dynamic_size=False,
|
||||
)
|
||||
# stop_flags = tf.TensorArray(dtype=tf.bool,
|
||||
# size=self.max_decoder_steps,
|
||||
# element_shape=tf.TensorShape(
|
||||
# []),
|
||||
# clear_after_read=False,
|
||||
# dynamic_size=False)
|
||||
attentions = ()
|
||||
stop_tokens = ()
|
||||
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
|
||||
# iter vars
|
||||
stop_flag = tf.constant(False, dtype=tf.bool)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, states, outputs, stop_flag):
|
||||
frame_next = states[0]
|
||||
prenet_next = self.prenet(frame_next, training=False)
|
||||
output, stop_token, states, _ = self.step(prenet_next, states, None, training=False)
|
||||
stop_token = tf.math.sigmoid(stop_token)
|
||||
stop_flag = tf.greater(stop_token, self.stop_thresh)
|
||||
stop_flag = tf.reduce_all(stop_flag)
|
||||
# stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
|
||||
|
||||
outputs = outputs.write(step, tf.reshape(output, [-1]))
|
||||
return step + 1, memory, states, outputs, stop_flag
|
||||
|
||||
cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
|
||||
step_count, memory, states, outputs, stop_flag = tf.while_loop(
|
||||
cond,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, states, outputs, stop_flag),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=self.max_decoder_steps,
|
||||
)
|
||||
|
||||
outputs = outputs.stack()
|
||||
outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter
|
||||
outputs = tf.expand_dims(outputs, axis=[0])
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
|
||||
if training:
|
||||
return self.decode(memory, states, frames, memory_seq_length)
|
||||
if self.enable_tflite:
|
||||
return self.decode_inference_tflite(memory, states)
|
||||
return self.decode_inference(memory, states)
|
|
@ -1,116 +0,0 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
|
||||
from TTS.tts.tf.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
||||
from TTS.tts.tf.utils.tf_utils import shape_list
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors, abstract-method
|
||||
class Tacotron2(keras.models.Model):
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
out_channels=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type="original",
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
attn_K=4,
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False,
|
||||
enable_tflite=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.out_channels = out_channels
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
self.num_speakers = num_speakers
|
||||
self.speaker_embed_dim = 256
|
||||
self.enable_tflite = enable_tflite
|
||||
|
||||
self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding")
|
||||
self.encoder = Encoder(512, name="encoder")
|
||||
# TODO: most of the decoder args have no use at the moment
|
||||
self.decoder = Decoder(
|
||||
decoder_output_dim,
|
||||
r,
|
||||
attn_type=attn_type,
|
||||
use_attn_win=attn_win,
|
||||
attn_norm=attn_norm,
|
||||
prenet_type=prenet_type,
|
||||
prenet_dropout=prenet_dropout,
|
||||
use_forward_attn=forward_attn,
|
||||
use_trans_agent=trans_agent,
|
||||
use_forward_attn_mask=forward_attn_mask,
|
||||
use_location_attn=location_attn,
|
||||
attn_K=attn_K,
|
||||
separate_stopnet=separate_stopnet,
|
||||
speaker_emb_dim=self.speaker_embed_dim,
|
||||
name="decoder",
|
||||
enable_tflite=enable_tflite,
|
||||
)
|
||||
self.postnet = Postnet(out_channels, 5, name="postnet")
|
||||
|
||||
@tf.function(experimental_relax_shapes=True)
|
||||
def call(self, characters, text_lengths=None, frames=None, training=None):
|
||||
if training:
|
||||
return self.training(characters, text_lengths, frames)
|
||||
if not training:
|
||||
return self.inference(characters)
|
||||
raise RuntimeError(" [!] Set model training mode True or False")
|
||||
|
||||
def training(self, characters, text_lengths, frames):
|
||||
B, T = shape_list(characters)
|
||||
embedding_vectors = self.embedding(characters, training=True)
|
||||
encoder_output = self.encoder(embedding_vectors, training=True)
|
||||
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
|
||||
decoder_frames, stop_tokens, attentions = self.decoder(
|
||||
encoder_output, decoder_states, frames, text_lengths, training=True
|
||||
)
|
||||
postnet_frames = self.postnet(decoder_frames, training=True)
|
||||
output_frames = decoder_frames + postnet_frames
|
||||
return decoder_frames, output_frames, attentions, stop_tokens
|
||||
|
||||
def inference(self, characters):
|
||||
B, T = shape_list(characters)
|
||||
embedding_vectors = self.embedding(characters, training=False)
|
||||
encoder_output = self.encoder(embedding_vectors, training=False)
|
||||
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
|
||||
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
|
||||
postnet_frames = self.postnet(decoder_frames, training=False)
|
||||
output_frames = decoder_frames + postnet_frames
|
||||
print(output_frames.shape)
|
||||
return decoder_frames, output_frames, attentions, stop_tokens
|
||||
|
||||
@tf.function(
|
||||
experimental_relax_shapes=True,
|
||||
input_signature=[
|
||||
tf.TensorSpec([1, None], dtype=tf.int32),
|
||||
],
|
||||
)
|
||||
def inference_tflite(self, characters):
|
||||
B, T = shape_list(characters)
|
||||
embedding_vectors = self.embedding(characters, training=False)
|
||||
encoder_output = self.encoder(embedding_vectors, training=False)
|
||||
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
|
||||
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
|
||||
postnet_frames = self.postnet(decoder_frames, training=False)
|
||||
output_frames = decoder_frames + postnet_frames
|
||||
print(output_frames.shape)
|
||||
return decoder_frames, output_frames, attentions, stop_tokens
|
||||
|
||||
def build_inference(
|
||||
self,
|
||||
):
|
||||
# TODO: issue https://github.com/PyCQA/pylint/issues/3613
|
||||
input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) # pylint: disable=unexpected-keyword-arg
|
||||
self(input_ids)
|
|
@ -1,87 +0,0 @@
|
|||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
# NOTE: linter has a problem with the current TF release
|
||||
# pylint: disable=no-value-for-parameter
|
||||
# pylint: disable=unexpected-keyword-arg
|
||||
|
||||
|
||||
def tf_create_dummy_inputs():
|
||||
"""Create dummy inputs for TF Tacotron2 model"""
|
||||
batch_size = 4
|
||||
max_input_length = 32
|
||||
max_mel_length = 128
|
||||
pad = 1
|
||||
n_chars = 24
|
||||
input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32)
|
||||
input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size])
|
||||
input_lengths[-1] = max_input_length
|
||||
input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32)
|
||||
mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80])
|
||||
mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size])
|
||||
mel_lengths[-1] = max_mel_length
|
||||
mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)
|
||||
return input_ids, input_lengths, mel_outputs, mel_lengths
|
||||
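# Resulting shapes: input_ids [4, 33] int32, input_lengths [4],
# mel_outputs [4, 129, 80] float32, mel_lengths [4]; the last entry of each
# length vector is pinned to the true maximum so masking code paths are
# exercised deterministically.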
|
||||
|
||||
def compare_torch_tf(torch_tensor, tf_tensor):
|
||||
"""Compute the average absolute difference b/w torch and tf tensors"""
|
||||
return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean()
|
||||
|
||||
|
||||
def convert_tf_name(tf_name):
|
||||
"""Convert certain patterns in TF layer names to Torch patterns"""
|
||||
tf_name_tmp = tf_name
|
||||
tf_name_tmp = tf_name_tmp.replace(":0", "")
|
||||
tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0")
|
||||
tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1")
|
||||
tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh")
|
||||
tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight")
|
||||
tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight")
|
||||
tf_name_tmp = tf_name_tmp.replace("/beta", "/bias")
|
||||
tf_name_tmp = tf_name_tmp.replace("/", ".")
|
||||
return tf_name_tmp
|
||||
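# Illustrative mapping (hypothetical layer name):
#   "encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0"
#   -> "encoder.lstm.weight_hh_l0"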
|
||||
|
||||
def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
|
||||
"""Transfer weigths from torch state_dict to TF variables"""
|
||||
print(" > Passing weights from Torch to TF ...")
|
||||
for tf_var in tf_vars:
|
||||
torch_var_name = var_map_dict[tf_var.name]
|
||||
print(f" | > {tf_var.name} <-- {torch_var_name}")
|
||||
# if tuple, it is a bias variable
|
||||
if not isinstance(torch_var_name, tuple):
|
||||
torch_layer_name = ".".join(torch_var_name.split(".")[-2:])
|
||||
torch_weight = state_dict[torch_var_name]
|
||||
if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name:
|
||||
# out_dim, in_dim, filter -> filter, in_dim, out_dim
|
||||
numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy()
|
||||
elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name:
|
||||
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
|
||||
# if variable is for bidirectional lstm and it is a bias vector there
|
||||
# needs to be pre-defined two matching torch bias vectors
|
||||
elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name:
|
||||
bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name]
|
||||
assert len(bias_vectors) == 2
|
||||
numpy_weight = bias_vectors[0] + bias_vectors[1]
|
||||
elif "rnn" in tf_var.name and "kernel" in tf_var.name:
|
||||
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
|
||||
elif "rnn" in tf_var.name and "bias" in tf_var.name:
|
||||
bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key]
|
||||
assert len(bias_vectors) == 2
|
||||
numpy_weight = bias_vectors[0] + bias_vectors[1]
|
||||
elif "linear_layer" in torch_layer_name and "weight" in torch_var_name:
|
||||
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
|
||||
else:
|
||||
numpy_weight = torch_weight.detach().cpu().numpy()
|
||||
assert np.all(
|
||||
tf_var.shape == numpy_weight.shape
|
||||
), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
|
||||
tf.keras.backend.set_value(tf_var, numpy_weight)
|
||||
return tf_vars
|
||||
|
||||
|
||||
def load_tf_vars(model_tf, tf_vars):
|
||||
for tf_var in tf_vars:
|
||||
model_tf.get_layer(tf_var.name).set_weights(tf_var)
|
||||
return model_tf
|
|
@ -1,105 +0,0 @@
|
|||
import datetime
|
||||
import importlib
|
||||
import pickle
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
|
||||
state = {
|
||||
"model": model.weights,
|
||||
"optimizer": optimizer,
|
||||
"step": current_step,
|
||||
"epoch": epoch,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
"r": r,
|
||||
}
|
||||
state.update(kwargs)
|
||||
with fsspec.open(output_path, "wb") as f:
|
||||
pickle.dump(state, f)
|
||||
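# Hedged usage sketch (path and values illustrative):
# save_checkpoint(model, optimizer, current_step=1000, epoch=5, r=2,
#                 output_path="checkpoints/checkpoint_1000.pkl")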
|
||||
|
||||
def load_checkpoint(model, checkpoint_path):
|
||||
with fsspec.open(checkpoint_path, "rb") as f:
|
||||
checkpoint = pickle.load(f)
|
||||
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
|
||||
tf_vars = model.weights
|
||||
for tf_var in tf_vars:
|
||||
layer_name = tf_var.name
|
||||
try:
|
||||
chkp_var_value = chkp_var_dict[layer_name]
|
||||
except KeyError:
|
||||
class_name = list(chkp_var_dict.keys())[0].split("/")[0]
|
||||
layer_name = f"{class_name}/{layer_name}"
|
||||
chkp_var_value = chkp_var_dict[layer_name]
|
||||
|
||||
tf.keras.backend.set_value(tf_var, chkp_var_value)
|
||||
if "r" in checkpoint.keys():
|
||||
model.decoder.set_r(checkpoint["r"])
|
||||
return model
|
||||
|
||||
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = sequence_length.max()
|
||||
batch_size = sequence_length.size(0)
|
||||
seq_range = torch.arange(0, max_len, dtype=torch.int64)  # fixed: np.empty([0, max_len]) cannot support the torch ops below (assumes torch is importable here)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
seq_range_expand = seq_range_expand.type_as(sequence_length)
|
||||
seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
|
||||
# B x T_max
|
||||
return seq_range_expand < seq_length_expand
|
||||
|
||||
|
||||
# @tf.custom_gradient
|
||||
def check_gradient(x, grad_clip):
|
||||
x_normed = tf.clip_by_norm(x, grad_clip)
|
||||
grad_norm = tf.norm(x)  # fixed: report the norm of the gradient tensor, not of the scalar clip value
|
||||
return x_normed, grad_norm
|
||||
|
||||
|
||||
def count_parameters(model, c):
|
||||
try:
|
||||
return model.count_params()
|
||||
except RuntimeError:
|
||||
input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32"))
|
||||
input_lengths = np.random.randint(100, 129, (8,))
|
||||
input_lengths[-1] = 128
|
||||
input_lengths = tf.convert_to_tensor(input_lengths.astype("int32"))
|
||||
mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32")
|
||||
mel_spec = tf.convert_to_tensor(mel_spec)
|
||||
speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None
|
||||
_ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids)
|
||||
return model.count_params()
|
||||
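# Note: the dummy forward pass above builds the model's layers so that
# count_params() has weights to count; counting an unbuilt subclassed
# Keras model raises instead.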
|
||||
|
||||
def setup_model(num_chars, num_speakers, c, enable_tflite=False):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower())
|
||||
MyModel = getattr(MyModel, c.model)
|
||||
if c.model.lower() in "tacotron":
|
||||
raise NotImplementedError(" [!] Tacotron model is not ready.")
|
||||
# tacotron2
|
||||
model = MyModel(
|
||||
num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
out_channels=c.audio["num_mels"],
|
||||
decoder_output_dim=c.audio["num_mels"],
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
enable_tflite=enable_tflite,
|
||||
)
|
||||
return model
|
|
@ -1,45 +0,0 @@
|
|||
import datetime
|
||||
import pickle
|
||||
|
||||
import fsspec
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
|
||||
state = {
|
||||
"model": model.weights,
|
||||
"optimizer": optimizer,
|
||||
"step": current_step,
|
||||
"epoch": epoch,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
"r": r,
|
||||
}
|
||||
state.update(kwargs)
|
||||
with fsspec.open(output_path, "wb") as f:
|
||||
pickle.dump(state, f)
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path):
|
||||
with fsspec.open(checkpoint_path, "rb") as f:
|
||||
checkpoint = pickle.load(f)
|
||||
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
|
||||
tf_vars = model.weights
|
||||
for tf_var in tf_vars:
|
||||
layer_name = tf_var.name
|
||||
try:
|
||||
chkp_var_value = chkp_var_dict[layer_name]
|
||||
except KeyError:
|
||||
class_name = list(chkp_var_dict.keys())[0].split("/")[0]
|
||||
layer_name = f"{class_name}/{layer_name}"
|
||||
chkp_var_value = chkp_var_dict[layer_name]
|
||||
|
||||
tf.keras.backend.set_value(tf_var, chkp_var_value)
|
||||
if "r" in checkpoint.keys():
|
||||
model.decoder.set_r(checkpoint["r"])
|
||||
return model
|
||||
|
||||
|
||||
def load_tflite_model(tflite_path):
|
||||
tflite_model = tf.lite.Interpreter(model_path=tflite_path)
|
||||
tflite_model.allocate_tensors()
|
||||
return tflite_model
|
|
@ -1,8 +0,0 @@
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def shape_list(x):
|
||||
"""Deal with dynamic shape in tensorflow cleanly."""
|
||||
static = x.shape.as_list()
|
||||
dynamic = tf.shape(x)
|
||||
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
|
|
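# Hedged usage sketch: with a fully static shape shape_list returns plain ints;
# under tf.function, a None dimension would come back as a scalar tensor instead.
x = tf.zeros([2, 3, 5])
assert shape_list(x) == [2, 3, 5]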
@ -1,27 +0,0 @@
|
|||
import fsspec
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True):
|
||||
"""Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is
|
||||
provided, else return TFLite model."""
|
||||
|
||||
concrete_function = model.inference_tflite.get_concrete_function()
|
||||
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
|
||||
converter.experimental_new_converter = experimental_converter
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
|
||||
tflite_model = converter.convert()
|
||||
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
|
||||
if output_path is not None:
|
||||
# save the model binary if output_path is provided
|
||||
with fsspec.open(output_path, "wb") as f:
|
||||
f.write(tflite_model)
|
||||
return None
|
||||
return tflite_model
|
||||
|
||||
|
||||
def load_tflite_model(tflite_path):
|
||||
tflite_model = tf.lite.Interpreter(model_path=tflite_path)
|
||||
tflite_model.allocate_tensors()
|
||||
return tflite_model
|
|
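# Minimal driving sketch for the interpreter returned above. Assumptions: a
# converted model at "tacotron2.tflite" (hypothetical path) and int32 token-ID
# input; tensor indices come from the interpreter's own detail dicts.
import numpy as np

interpreter = load_tflite_model("tacotron2.tflite")
inp = interpreter.get_input_details()[0]
token_ids = np.zeros((1, 32), dtype=np.int32)  # illustrative dummy input
interpreter.resize_tensor_input(inp["index"], token_ids.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(inp["index"], token_ids)
interpreter.invoke()
outputs = [interpreter.get_tensor(o["index"]) for o in interpreter.get_output_details()]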
@ -57,40 +57,65 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
return mask
|
||||
|
||||
|
||||
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
|
||||
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
|
||||
"""Segment each sample in a batch based on the provided segment indices
|
||||
|
||||
Args:
|
||||
x (torch.tensor): Input tensor.
|
||||
segment_indices (torch.tensor): Segment indices.
|
||||
segment_size (int): Expected output segment size.
|
||||
pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
|
||||
"""
|
||||
# pad the input tensor if it is shorter than the segment size
|
||||
if pad_short and x.shape[-1] < segment_size:
|
||||
x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
|
||||
|
||||
segments = torch.zeros_like(x[:, :, :segment_size])
|
||||
|
||||
for i in range(x.size(0)):
|
||||
index_start = segment_indices[i]
|
||||
index_end = index_start + segment_size
|
||||
segments[i] = x[i, :, index_start:index_end]
|
||||
x_i = x[i]
|
||||
if pad_short and index_end > x.size(2):
|
||||
# pad the sample if it is shorter than the segment size
|
||||
x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
|
||||
segments[i] = x_i[:, index_start:index_end]
|
||||
return segments
|
||||
|
||||
|
||||
def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
|
||||
def rand_segments(
|
||||
x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
|
||||
):
|
||||
"""Create random segments based on the input lengths.
|
||||
|
||||
Args:
|
||||
x (torch.tensor): Input tensor.
|
||||
x_lengths (torch.tensor): Input lengths.
|
||||
segment_size (int): Expected output segment size.
|
||||
let_short_samples (bool): Allow shorter samples than the segment size.
|
||||
pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
|
||||
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_lengths: :math:`[B]`
|
||||
"""
|
||||
_x_lenghts = x_lengths.clone() if x_lengths is not None else None  # guard: x_lengths defaults to None
|
||||
B, _, T = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = T
|
||||
max_idxs = x_lengths - segment_size + 1
|
||||
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
|
||||
segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
|
||||
if pad_short:
|
||||
if T < segment_size:
|
||||
x = torch.nn.functional.pad(x, (0, segment_size - T))
|
||||
T = segment_size
|
||||
if _x_lenghts is None:
|
||||
_x_lenghts = T
|
||||
len_diff = _x_lenghts - segment_size + 1
|
||||
if let_short_samples:
|
||||
_x_lenghts[len_diff < 0] = segment_size
|
||||
len_diff = _x_lenghts - segment_size + 1
|
||||
else:
|
||||
assert all(
|
||||
len_diff > 0
|
||||
), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
|
||||
segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
|
||||
ret = segment(x, segment_indices, segment_size)
|
||||
return ret, segment_indices
|
||||
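# Hedged usage sketch for rand_segments (shapes follow the docstring above;
# values are illustrative):
import torch

x = torch.randn(4, 80, 100)  # [B, C, T]
x_lengths = torch.tensor([100, 90, 80, 70])
segs, idxs = rand_segments(x, x_lengths, segment_size=32)
# segs: [4, 80, 32]; idxs: one random start index per batch item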
|
||||
|
|
|
@ -8,6 +8,8 @@ import torch
|
|||
from coqpit import Coqpit
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
|
||||
from TTS.config import check_config_and_model_args
|
||||
|
||||
|
||||
class LanguageManager:
|
||||
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
|
||||
|
@ -98,6 +100,20 @@ class LanguageManager:
|
|||
"""
|
||||
self._save_json(file_path, self.language_id_mapping)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: Coqpit) -> "LanguageManager":
|
||||
"""Initialize the language manager from a Coqpit config.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Coqpit config.
|
||||
"""
|
||||
language_manager = None
|
||||
if check_config_and_model_args(config, "use_language_embedding", True):
|
||||
if config.get("language_ids_file", None):
|
||||
language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
|
||||
language_manager = LanguageManager(config=config)
|
||||
return language_manager
|
||||
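# Hedged usage sketch (config fields per the checks above; values illustrative):
# config.use_language_embedding = True
# manager = LanguageManager.init_from_config(config)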
|
||||
|
||||
def _set_file_path(path):
|
||||
"""Find the language_ids.json under the given path or the above it.
|
||||
|
@ -113,7 +129,7 @@ def _set_file_path(path):
|
|||
|
||||
|
||||
def get_language_weighted_sampler(items: list):
|
||||
language_names = np.array([item[3] for item in items])
|
||||
language_names = np.array([item["language"] for item in items])
|
||||
unique_language_names = np.unique(language_names).tolist()
|
||||
language_ids = [unique_language_names.index(l) for l in language_names]
|
||||
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch
|
|||
from coqpit import Coqpit
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.config import get_from_config_or_model_args_with_default, load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
@ -118,7 +118,7 @@ class SpeakerManager:
|
|||
Returns:
|
||||
Tuple[Dict, int]: speaker IDs and number of speakers.
|
||||
"""
|
||||
speakers = sorted({item[2] for item in items})
|
||||
speakers = sorted({item["speaker_name"] for item in items})
|
||||
speaker_ids = {name: i for i, name in enumerate(speakers)}
|
||||
num_speakers = len(speaker_ids)
|
||||
return speaker_ids, num_speakers
|
||||
|
@ -318,6 +318,42 @@ class SpeakerManager:
|
|||
# TODO: implement speaker encoder
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
|
||||
"""Initialize a speaker manager from config
|
||||
|
||||
Args:
|
||||
config (Coqpit): Config object.
|
||||
samples (Union[List[List], List[Dict]], optional): List of data samples to parse out the speaker names.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
SpeakerManager: Speaker manager object.
|
||||
"""
|
||||
speaker_manager = None
|
||||
if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False):
|
||||
if samples:
|
||||
speaker_manager = SpeakerManager(data_items=samples)
|
||||
if get_from_config_or_model_args_with_default(config, "speaker_file", None):
|
||||
speaker_manager = SpeakerManager(
|
||||
speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
|
||||
)
|
||||
if get_from_config_or_model_args_with_default(config, "speakers_file", None):
|
||||
speaker_manager = SpeakerManager(
|
||||
speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None)
|
||||
)
|
||||
|
||||
if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False):
|
||||
if get_from_config_or_model_args_with_default(config, "speakers_file", None):
|
||||
speaker_manager = SpeakerManager(
|
||||
d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
|
||||
)
|
||||
if get_from_config_or_model_args_with_default(config, "d_vector_file", None):
|
||||
speaker_manager = SpeakerManager(
|
||||
d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None)
|
||||
)
|
||||
return speaker_manager
|
||||
|
||||
|
||||
def _set_file_path(path):
|
||||
"""Find the speakers.json under the given path or the above it.
|
||||
|
@ -414,7 +450,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
|||
|
||||
|
||||
def get_speaker_weighted_sampler(items: list):
|
||||
speaker_names = np.array([item[2] for item in items])
|
||||
speaker_names = np.array([item["speaker_name"] for item in items])
|
||||
unique_speaker_names = np.unique(speaker_names).tolist()
|
||||
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
|
||||
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
|
||||
|
|
|
@ -8,7 +8,7 @@ from torch.autograd import Variable
|
|||
|
||||
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)])
|
||||
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
|
||||
return gauss / gauss.sum()
|
||||
|
||||
|
||||
|
@ -33,8 +33,8 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
|||
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
||||
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
||||
|
||||
C1 = 0.01 ** 2
|
||||
C2 = 0.03 ** 2
|
||||
C1 = 0.01**2
|
||||
C2 = 0.03**2
|
||||
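# Note: C1 = (K1 * L)**2 and C2 = (K2 * L)**2 with K1 = 0.01, K2 = 0.03 and a
# dynamic range L = 1 for normalized inputs; they stabilize the division below
# (Wang et al., 2004).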
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||
|
||||
|
|
|
@ -1,46 +1,9 @@
|
|||
import os
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import pkg_resources
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .text import phoneme_to_sequence, text_to_sequence
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable
|
||||
if "tensorflow" in installed or "tensorflow-gpu" in installed:
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def text_to_seq(text, CONFIG, custom_symbols=None, language=None):
|
||||
text_cleaner = [CONFIG.text_cleaner]
|
||||
# text or phonemes to sequence vector
|
||||
if CONFIG.use_phonemes:
|
||||
seq = np.asarray(
|
||||
phoneme_to_sequence(
|
||||
text,
|
||||
text_cleaner,
|
||||
language if language else CONFIG.phoneme_language,
|
||||
CONFIG.enable_eos_bos_chars,
|
||||
tp=CONFIG.characters,
|
||||
add_blank=CONFIG.add_blank,
|
||||
use_espeak_phonemes=CONFIG.use_espeak_phonemes,
|
||||
custom_symbols=custom_symbols,
|
||||
),
|
||||
dtype=np.int32,
|
||||
)
|
||||
else:
|
||||
seq = np.asarray(
|
||||
text_to_sequence(
|
||||
text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
|
||||
),
|
||||
dtype=np.int32,
|
||||
)
|
||||
return seq
|
||||
|
||||
|
||||
def numpy_to_torch(np_array, dtype, cuda=False):
|
||||
if np_array is None:
|
||||
|
@ -51,13 +14,6 @@ def numpy_to_torch(np_array, dtype, cuda=False):
|
|||
return tensor
|
||||
|
||||
|
||||
def numpy_to_tf(np_array, dtype):
|
||||
if np_array is None:
|
||||
return None
|
||||
tensor = tf.convert_to_tensor(np_array, dtype=dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def compute_style_mel(style_wav, ap, cuda=False):
|
||||
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
|
||||
if cuda:
|
||||
|
@ -103,53 +59,6 @@ def run_model_torch(
|
|||
return outputs
|
||||
|
||||
|
||||
def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
|
||||
if CONFIG.gst and style_mel is not None:
|
||||
raise NotImplementedError(" [!] GST inference not implemented for TF")
|
||||
if speaker_id is not None:
|
||||
raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
|
||||
# TODO: handle multispeaker case
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
|
||||
return decoder_output, postnet_output, alignments, stop_tokens
|
||||
|
||||
|
||||
def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
|
||||
if CONFIG.gst and style_mel is not None:
|
||||
raise NotImplementedError(" [!] GST inference not implemented for TfLite")
|
||||
if speaker_id is not None:
|
||||
raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
|
||||
# get input and output details
|
||||
input_details = model.get_input_details()
|
||||
output_details = model.get_output_details()
|
||||
# reshape input tensor for the new input shape
|
||||
model.resize_tensor_input(input_details[0]["index"], inputs.shape)
|
||||
model.allocate_tensors()
|
||||
detail = input_details[0]
|
||||
# input_shape = detail['shape']
|
||||
model.set_tensor(detail["index"], inputs)
|
||||
# run the model
|
||||
model.invoke()
|
||||
# collect outputs
|
||||
decoder_output = model.get_tensor(output_details[0]["index"])
|
||||
postnet_output = model.get_tensor(output_details[1]["index"])
|
||||
# tflite model only returns feature frames
|
||||
return decoder_output, postnet_output, None, None
|
||||
|
||||
|
||||
def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
|
||||
postnet_output = postnet_output[0].numpy()
|
||||
decoder_output = decoder_output[0].numpy()
|
||||
alignment = alignments[0].numpy()
|
||||
stop_tokens = stop_tokens[0].numpy()
|
||||
return postnet_output, decoder_output, alignment, stop_tokens
|
||||
|
||||
|
||||
def parse_outputs_tflite(postnet_output, decoder_output):
|
||||
postnet_output = postnet_output[0]
|
||||
decoder_output = decoder_output[0]
|
||||
return postnet_output, decoder_output
|
||||
|
||||
|
||||
def trim_silence(wav, ap):
|
||||
return wav[: ap.find_endpoint(wav)]
|
||||
|
||||
|
@ -204,16 +113,12 @@ def synthesis(
|
|||
text,
|
||||
CONFIG,
|
||||
use_cuda,
|
||||
ap,
|
||||
speaker_id=None,
|
||||
style_wav=None,
|
||||
enable_eos_bos_chars=False, # pylint: disable=unused-argument
|
||||
use_griffin_lim=False,
|
||||
do_trim_silence=False,
|
||||
d_vector=None,
|
||||
language_id=None,
|
||||
language_name=None,
|
||||
backend="torch",
|
||||
):
|
||||
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
|
||||
the vocoder model.
|
||||
|
@ -231,9 +136,6 @@ def synthesis(
|
|||
use_cuda (bool):
|
||||
Enable/disable CUDA.
|
||||
|
||||
ap (TTS.tts.utils.audio.AudioProcessor):
|
||||
The audio processor for extracting features and pre/post-processing audio.
|
||||
|
||||
speaker_id (int):
|
||||
Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None.
|
||||
|
||||
|
@ -251,74 +153,51 @@ def synthesis(
|
|||
|
||||
language_id (int):
|
||||
Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
|
||||
|
||||
language_name (str):
|
||||
Language name corresponding to the language code used by the phonemizer. Defaults to None.
|
||||
|
||||
backend (str):
|
||||
tf or torch. Defaults to "torch".
|
||||
"""
|
||||
# GST processing
|
||||
style_mel = None
|
||||
custom_symbols = None
|
||||
if style_wav:
|
||||
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
|
||||
elif CONFIG.has("gst") and CONFIG.gst and not style_wav:
|
||||
if CONFIG.gst.gst_style_input_weights:
|
||||
style_mel = CONFIG.gst.gst_style_input_weights
|
||||
if hasattr(model, "make_symbols"):
|
||||
custom_symbols = model.make_symbols(CONFIG)
|
||||
# preprocess the given text
|
||||
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name)
|
||||
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
|
||||
if isinstance(style_wav, dict):
|
||||
style_mel = style_wav
|
||||
else:
|
||||
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
|
||||
# convert text to sequence of token IDs
|
||||
text_inputs = np.asarray(
|
||||
model.tokenizer.text_to_ids(text, language=language_id),
|
||||
dtype=np.int32,
|
||||
)
|
||||
# pass tensors to backend
|
||||
if backend == "torch":
|
||||
if speaker_id is not None:
|
||||
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
|
||||
if speaker_id is not None:
|
||||
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
|
||||
|
||||
if d_vector is not None:
|
||||
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
|
||||
if d_vector is not None:
|
||||
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
|
||||
|
||||
if language_id is not None:
|
||||
language_id = id_to_torch(language_id, cuda=use_cuda)
|
||||
if language_id is not None:
|
||||
language_id = id_to_torch(language_id, cuda=use_cuda)
|
||||
|
||||
if not isinstance(style_mel, dict):
|
||||
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
|
||||
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
|
||||
text_inputs = text_inputs.unsqueeze(0)
|
||||
elif backend in ["tf", "tflite"]:
|
||||
# TODO: handle speaker id for tf model
|
||||
style_mel = numpy_to_tf(style_mel, tf.float32)
|
||||
text_inputs = numpy_to_tf(text_inputs, tf.int32)
|
||||
text_inputs = tf.expand_dims(text_inputs, 0)
|
||||
if not isinstance(style_mel, dict):
|
||||
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
|
||||
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
|
||||
text_inputs = text_inputs.unsqueeze(0)
|
||||
# synthesize voice
|
||||
if backend == "torch":
|
||||
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
|
||||
model_outputs = outputs["model_outputs"]
|
||||
model_outputs = model_outputs[0].data.cpu().numpy()
|
||||
alignments = outputs["alignments"]
|
||||
elif backend == "tf":
|
||||
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
|
||||
model, text_inputs, CONFIG, speaker_id, style_mel
|
||||
)
|
||||
model_outputs, decoder_output, alignments, stop_tokens = parse_outputs_tf(
|
||||
postnet_output, decoder_output, alignments, stop_tokens
|
||||
)
|
||||
elif backend == "tflite":
|
||||
decoder_output, postnet_output, alignments, stop_tokens = run_model_tflite(
|
||||
model, text_inputs, CONFIG, speaker_id, style_mel
|
||||
)
|
||||
model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
|
||||
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
|
||||
model_outputs = outputs["model_outputs"]
|
||||
model_outputs = model_outputs[0].data.cpu().numpy()
|
||||
alignments = outputs["alignments"]
|
||||
|
||||
# convert outputs to numpy
|
||||
# plot results
|
||||
wav = None
|
||||
if hasattr(model, "END2END") and model.END2END:
|
||||
wav = model_outputs.squeeze(0)
|
||||
else:
|
||||
model_outputs = model_outputs.squeeze()
|
||||
if model_outputs.ndim == 2: # [T, C_spec]
|
||||
if use_griffin_lim:
|
||||
wav = inv_spectrogram(model_outputs, ap, CONFIG)
|
||||
wav = inv_spectrogram(model_outputs, model.ap, CONFIG)
|
||||
# trim silence
|
||||
if do_trim_silence:
|
||||
wav = trim_silence(wav, ap)
|
||||
wav = trim_silence(wav, model.ap)
|
||||
else: # [T,]
|
||||
wav = model_outputs
|
||||
return_dict = {
|
||||
"wav": wav,
|
||||
"alignments": alignments,
|
||||
|
|
|
@ -1,276 +1 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# adapted from https://github.com/keithito/tacotron
|
||||
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
import gruut
|
||||
from gruut_ipa import IPA
|
||||
|
||||
from TTS.tts.utils.text import cleaners
|
||||
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
|
||||
from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
|
||||
from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols
|
||||
|
||||
# pylint: disable=unnecessary-comprehension
|
||||
# Mappings from symbol to numeric ID and vice versa:
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
||||
|
||||
_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
|
||||
_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
|
||||
|
||||
_symbols = symbols
|
||||
_phonemes = phonemes
|
||||
|
||||
# Regular expression matching text enclosed in curly braces:
|
||||
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
|
||||
|
||||
# Regular expression matching punctuation characters, ignoring spaces
|
||||
PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+"
|
||||
|
||||
# Table for str.translate to fix gruut/TTS phoneme mismatch
|
||||
GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
|
||||
|
||||
|
||||
def text2phone(text, language, use_espeak_phonemes=False, keep_stress=False):
|
||||
"""Convert graphemes to phonemes.
|
||||
Parameters:
|
||||
text (str): text to phonemize
|
||||
language (str): language of the text
|
||||
Returns:
|
||||
ph (str): phonemes as a string separated by "|"
|
||||
ph = "ɪ|g|ˈ|z|æ|m|p|ə|l"
|
||||
"""
|
||||
|
||||
# TO REVIEW: how best to implement this?
|
||||
if language == "zh-CN":
|
||||
ph = chinese_text_to_phonemes(text)
|
||||
return ph
|
||||
|
||||
if language == "ja-jp":
|
||||
ph = japanese_text_to_phonemes(text)
|
||||
return ph
|
||||
|
||||
if not gruut.is_language_supported(language):
|
||||
raise ValueError(f" [!] Language {language} is not supported for phonemization.")
|
||||
|
||||
# Use gruut for phonemization
|
||||
ph_list = []
|
||||
for sentence in gruut.sentences(text, lang=language, espeak=use_espeak_phonemes):
|
||||
for word in sentence:
|
||||
if word.is_break:
|
||||
# Use actual character for break phoneme (e.g., comma)
|
||||
if ph_list:
|
||||
# Join with previous word
|
||||
ph_list[-1].append(word.text)
|
||||
else:
|
||||
# First word is punctuation
|
||||
ph_list.append([word.text])
|
||||
elif word.phonemes:
|
||||
# Add phonemes for word
|
||||
word_phonemes = []
|
||||
|
||||
for word_phoneme in word.phonemes:
|
||||
if not keep_stress:
|
||||
# Remove primary/secondary stress
|
||||
word_phoneme = IPA.without_stress(word_phoneme)
|
||||
|
||||
word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
|
||||
|
||||
if word_phoneme:
|
||||
# Flatten phonemes
|
||||
word_phonemes.extend(word_phoneme)
|
||||
|
||||
if word_phonemes:
|
||||
ph_list.append(word_phonemes)
|
||||
|
||||
# Join and re-split to break apart diphthongs, suprasegmentals, etc.
|
||||
ph_words = ["|".join(word_phonemes) for word_phonemes in ph_list]
|
||||
ph = "| ".join(ph_words)
|
||||
|
||||
return ph
|
||||
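# Note on the output format: phonemes within a word are joined by "|" and words
# by "| ", e.g. "h|ə|l|ˈoʊ| w|ˈɜː|l|d" (exact phonemes depend on the gruut version).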
|
||||
|
||||
def intersperse(sequence, token):
|
||||
result = [token] * (len(sequence) * 2 + 1)
|
||||
result[1::2] = sequence
|
||||
return result
|
||||
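# Worked example: intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]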
|
||||
|
||||
def pad_with_eos_bos(phoneme_sequence, tp=None):
|
||||
# pylint: disable=global-statement
|
||||
global _phonemes_to_id, _bos, _eos
|
||||
if tp:
|
||||
_bos = tp["bos"]
|
||||
_eos = tp["eos"]
|
||||
_, _phonemes = make_symbols(**tp)
|
||||
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
|
||||
|
||||
return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]
|
||||
|
||||
|
||||
def phoneme_to_sequence(
|
||||
text: str,
|
||||
cleaner_names: List[str],
|
||||
language: str,
|
||||
enable_eos_bos: bool = False,
|
||||
custom_symbols: List[str] = None,
|
||||
tp: Dict = None,
|
||||
add_blank: bool = False,
|
||||
use_espeak_phonemes: bool = False,
|
||||
) -> List[int]:
|
||||
"""Converts a string of phonemes to a sequence of IDs.
|
||||
If `custom_symbols` is provided, it will override the default symbols.
|
||||
|
||||
Args:
|
||||
text (str): string to convert to a sequence
|
||||
cleaner_names (List[str]): names of the cleaner functions to run the text through
|
||||
language (str): text language key for phonemization.
|
||||
enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
|
||||
tp (Dict): dictionary of character parameters to use a custom character set.
|
||||
add_blank (bool): option to add a blank token between each token.
|
||||
use_espeak_phonemes (bool): use espeak-based lexicons to convert phonemes to a sequence.
|
||||
|
||||
Returns:
|
||||
List[int]: List of integers corresponding to the symbols in the text
|
||||
"""
|
||||
# pylint: disable=global-statement
|
||||
global _phonemes_to_id, _phonemes
|
||||
|
||||
if custom_symbols is not None:
|
||||
_phonemes = custom_symbols
|
||||
elif tp:
|
||||
_, _phonemes = make_symbols(**tp)
|
||||
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
|
||||
|
||||
sequence = []
|
||||
clean_text = _clean_text(text, cleaner_names)
|
||||
to_phonemes = text2phone(clean_text, language, use_espeak_phonemes=use_espeak_phonemes)
|
||||
if to_phonemes is None:
|
||||
print("!! After phoneme conversion the result is None. -- {} ".format(clean_text))
|
||||
# iterate while skipping empty strings - NOTE: keeping them might be useful for better intonation.
|
||||
for phoneme in filter(None, to_phonemes.split("|")):
|
||||
sequence += _phoneme_to_sequence(phoneme)
|
||||
# Append EOS char
|
||||
if enable_eos_bos:
|
||||
sequence = pad_with_eos_bos(sequence, tp=tp)
|
||||
if add_blank:
|
||||
sequence = intersperse(sequence, len(_phonemes))  # add a blank token whose ID is len(_phonemes)
|
||||
return sequence
|
||||
|
||||
|
||||
def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None):
|
||||
# pylint: disable=global-statement
|
||||
"""Converts a sequence of IDs back to a string"""
|
||||
global _id_to_phonemes, _phonemes
|
||||
if add_blank:
|
||||
sequence = list(filter(lambda x: x != len(_phonemes), sequence))
|
||||
result = ""
|
||||
|
||||
if custom_symbols is not None:
|
||||
_phonemes = custom_symbols
|
||||
elif tp:
|
||||
_, _phonemes = make_symbols(**tp)
|
||||
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
|
||||
|
||||
for symbol_id in sequence:
|
||||
if symbol_id in _id_to_phonemes:
|
||||
s = _id_to_phonemes[symbol_id]
|
||||
result += s
|
||||
return result.replace("}{", " ")
|
||||
|
||||
|
||||
def text_to_sequence(
|
||||
text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
|
||||
) -> List[int]:
|
||||
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
If `custom_symbols` is provided, it will override the default symbols.
|
||||
|
||||
Args:
|
||||
text (str): string to convert to a sequence
|
||||
cleaner_names (List[str]): names of the cleaner functions to run the text through
|
||||
tp (Dict): dictionary of character parameters to use a custom character set.
|
||||
add_blank (bool): option to add a blank token between each token.
|
||||
|
||||
Returns:
|
||||
List[int]: List of integers corresponding to the symbols in the text
|
||||
"""
|
||||
# pylint: disable=global-statement
|
||||
global _symbol_to_id, _symbols
|
||||
|
||||
if custom_symbols is not None:
|
||||
_symbols = custom_symbols
|
||||
elif tp:
|
||||
_symbols, _ = make_symbols(**tp)
|
||||
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
|
||||
|
||||
sequence = []
|
||||
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while text:
|
||||
m = _CURLY_RE.match(text)
|
||||
if not m:
|
||||
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
|
||||
break
|
||||
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
|
||||
sequence += _arpabet_to_sequence(m.group(2))
|
||||
text = m.group(3)
|
||||
|
||||
if add_blank:
|
||||
sequence = intersperse(sequence, len(_symbols))  # add a blank token whose ID is len(_symbols)
|
||||
return sequence
|
||||
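# Example of the curly-brace convention handled above (illustrative sentence):
# "turn left on {HH AW1 S S T AH0 N} street" -- the braced span is looked up as
# ARPAbet symbols ("@"-prefixed) instead of being cleaned as plain text.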
|
||||
|
||||
def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
|
||||
"""Converts a sequence of IDs back to a string"""
|
||||
# pylint: disable=global-statement
|
||||
global _id_to_symbol, _symbols
|
||||
if add_blank:
|
||||
sequence = list(filter(lambda x: x != len(_symbols), sequence))
|
||||
|
||||
if custom_symbols is not None:
|
||||
_symbols = custom_symbols
|
||||
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
|
||||
elif tp:
|
||||
_symbols, _ = make_symbols(**tp)
|
||||
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
|
||||
|
||||
result = ""
|
||||
for symbol_id in sequence:
|
||||
if symbol_id in _id_to_symbol:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
if len(s) > 1 and s[0] == "@":
|
||||
s = "{%s}" % s[1:]
|
||||
result += s
|
||||
return result.replace("}{", " ")
|
||||
|
||||
|
||||
def _clean_text(text, cleaner_names):
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception("Unknown cleaner: %s" % name)
|
||||
text = cleaner(text)
|
||||
return text
|
||||
|
||||
|
||||
def _symbols_to_sequence(syms):
|
||||
return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)]
|
||||
|
||||
|
||||
def _phoneme_to_sequence(phons):
|
||||
return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)]
|
||||
|
||||
|
||||
def _arpabet_to_sequence(text):
|
||||
return _symbols_to_sequence(["@" + s for s in text.split()])
|
||||
|
||||
|
||||
def _should_keep_symbol(s):
|
||||
return s in _symbol_to_id and s not in ["~", "^", "_"]
|
||||
|
||||
|
||||
def _should_keep_phoneme(p):
|
||||
return p in _phonemes_to_id and p not in ["~", "^", "_"]
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
|
|
|
@ -0,0 +1,468 @@
|
|||
from dataclasses import replace
|
||||
from typing import Dict
|
||||
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
|
||||
|
||||
def parse_symbols():
|
||||
return {
|
||||
"pad": _pad,
|
||||
"eos": _eos,
|
||||
"bos": _bos,
|
||||
"characters": _characters,
|
||||
"punctuations": _punctuations,
|
||||
"phonemes": _phonemes,
|
||||
}
|
||||
|
||||
|
||||
# DEFAULT SET OF GRAPHEMES
|
||||
_pad = "<PAD>"
|
||||
_eos = "<EOS>"
|
||||
_bos = "<BOS>"
|
||||
_blank = "<BLNK>" # TODO: check if we need this alongside with PAD
|
||||
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
_punctuations = "!'(),-.:;? "
|
||||
|
||||
|
||||
# DEFAULT SET OF IPA PHONEMES
|
||||
# Phonemes definition (All IPA characters)
|
||||
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
|
||||
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
|
||||
_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
|
||||
_suprasegmentals = "ˈˌːˑ"
|
||||
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
|
||||
_diacrilics = "ɚ˞ɫ"
|
||||
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
|
||||
|
||||
|
||||
class BaseVocabulary:
|
||||
"""Base Vocabulary class.
|
||||
|
||||
This class only needs a vocabulary dictionary without specifying the characters.
|
||||
|
||||
Args:
|
||||
vocab (Dict): A dictionary of characters and their corresponding indices.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
|
||||
self.vocab = vocab
|
||||
self.pad = pad
|
||||
self.blank = blank
|
||||
self.bos = bos
|
||||
self.eos = eos
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
"""Return the index of the padding character. If the padding character is not specified, return the length
|
||||
of the vocabulary."""
|
||||
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
|
||||
|
||||
@property
|
||||
def blank_id(self) -> int:
|
||||
"""Return the index of the blank character. If the blank character is not specified, return the length of
|
||||
the vocabulary."""
|
||||
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
"""Return the vocabulary dictionary."""
|
||||
return self._vocab
|
||||
|
||||
@vocab.setter
|
||||
def vocab(self, vocab):
|
||||
"""Set the vocabulary dictionary and character mapping dictionaries."""
|
||||
self._vocab = vocab
|
||||
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
|
||||
self._id_to_char = {
|
||||
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config, **kwargs):
|
||||
"""Initialize from the given config."""
|
||||
if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
|
||||
return (
|
||||
BaseVocabulary(
|
||||
config.characters.vocab_dict,
|
||||
config.characters.pad,
|
||||
config.characters.blank,
|
||||
config.characters.bos,
|
||||
config.characters.eos,
|
||||
),
|
||||
config,
|
||||
)
|
||||
return BaseVocabulary(**kwargs), config
|
||||
|
||||
@property
|
||||
def num_chars(self):
|
||||
"""Return number of tokens in the vocabulary."""
|
||||
return len(self._vocab)
|
||||
|
||||
def char_to_id(self, char: str) -> int:
|
||||
"""Map a character to an token ID."""
|
||||
try:
|
||||
return self._char_to_id[char]
|
||||
except KeyError as e:
|
||||
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
|
||||
|
||||
def id_to_char(self, idx: int) -> str:
|
||||
"""Map an token ID to a character."""
|
||||
return self._id_to_char[idx]
|
||||
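# Hedged usage sketch (toy vocabulary, not from a real model):
v = BaseVocabulary({"a": 0, "b": 1})
assert v.char_to_id("a") == 0 and v.id_to_char(1) == "b"
assert v.pad_id == v.num_chars  # no pad char set -> falls back to len(vocab)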
|
||||
|
||||
class BaseCharacters:
|
||||
"""🐸BaseCharacters class
|
||||
|
||||
Every new character class should inherit from this.
|
||||
|
||||
Characters are ordered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```.
|
||||
|
||||
If you need a custom order, inherit from this class and override the ```_create_vocab``` method.
|
||||
|
||||
Args:
|
||||
characters (str):
|
||||
Main set of characters to be used in the vocabulary.
|
||||
|
||||
punctuations (str):
|
||||
Characters to be treated as punctuation.
|
||||
|
||||
pad (str):
|
||||
Special padding character that would be ignored by the model.
|
||||
|
||||
eos (str):
|
||||
End of the sentence character.
|
||||
|
||||
bos (str):
|
||||
Beginning of the sentence character.
|
||||
|
||||
blank (str):
|
||||
Optional character used between characters by some models for better prosody.
|
||||
|
||||
is_unique (bool):
|
||||
Remove duplicates from the provided characters. Defaults to True.
|
||||
|
||||
is_sorted (bool):
|
||||
Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characters: str = None,
|
||||
punctuations: str = None,
|
||||
pad: str = None,
|
||||
eos: str = None,
|
||||
bos: str = None,
|
||||
blank: str = None,
|
||||
is_unique: bool = False,
|
||||
is_sorted: bool = True,
|
||||
) -> None:
|
||||
self._characters = characters
|
||||
self._punctuations = punctuations
|
||||
self._pad = pad
|
||||
self._eos = eos
|
||||
self._bos = bos
|
||||
self._blank = blank
|
||||
self.is_unique = is_unique
|
||||
self.is_sorted = is_sorted
|
||||
self._create_vocab()
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self.char_to_id(self.pad) if self.pad else len(self.vocab)
|
||||
|
||||
@property
|
||||
def blank_id(self) -> int:
|
||||
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
||||
|
||||
@property
|
||||
def characters(self):
|
||||
return self._characters
|
||||
|
||||
@characters.setter
|
||||
def characters(self, characters):
|
||||
self._characters = characters
|
||||
self._create_vocab()
|
||||
|
||||
@property
|
||||
def punctuations(self):
|
||||
return self._punctuations
|
||||
|
||||
@punctuations.setter
|
||||
def punctuations(self, punctuations):
|
||||
self._punctuations = punctuations
|
||||
self._create_vocab()
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
return self._pad
|
||||
|
||||
@pad.setter
|
||||
def pad(self, pad):
|
||||
self._pad = pad
|
||||
self._create_vocab()
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
return self._eos
|
||||
|
||||
@eos.setter
|
||||
def eos(self, eos):
|
||||
self._eos = eos
|
||||
self._create_vocab()
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
return self._bos
|
||||
|
||||
@bos.setter
|
||||
def bos(self, bos):
|
||||
self._bos = bos
|
||||
self._create_vocab()
|
||||
|
||||
@property
|
||||
def blank(self):
|
||||
return self._blank
|
||||
|
||||
@blank.setter
|
||||
def blank(self, blank):
|
||||
self._blank = blank
|
||||
self._create_vocab()
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self._vocab
|
||||
|
||||
@vocab.setter
|
||||
def vocab(self, vocab):
|
||||
self._vocab = vocab
|
||||
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
||||
self._id_to_char = {
|
||||
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
|
||||
}
|
||||
|
||||
@property
|
||||
def num_chars(self):
|
||||
return len(self._vocab)
|
||||
|
||||
def _create_vocab(self):
|
||||
_vocab = self._characters
|
||||
if self.is_unique:
|
||||
_vocab = list(set(_vocab))
|
||||
if self.is_sorted:
|
||||
_vocab = sorted(_vocab)
|
||||
_vocab = list(_vocab)
|
||||
_vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab
|
||||
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
|
||||
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
|
||||
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
|
||||
self.vocab = _vocab + list(self._punctuations)
|
||||
if self.is_unique:
|
||||
duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
|
||||
assert (
|
||||
len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
|
||||
), f" [!] There are duplicate characters in the character set. {duplicates}"
|
||||
|
||||
def char_to_id(self, char: str) -> int:
|
||||
try:
|
||||
return self._char_to_id[char]
|
||||
except KeyError as e:
|
||||
raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
|
||||
|
||||
def id_to_char(self, idx: int) -> str:
|
||||
return self._id_to_char[idx]
|
||||
|
||||
def print_log(self, level: int = 0):
|
||||
"""
|
||||
Prints the vocabulary in a nice format.
|
||||
"""
|
||||
indent = "\t" * level
|
||||
print(f"{indent}| > Characters: {self._characters}")
|
||||
print(f"{indent}| > Punctuations: {self._punctuations}")
|
||||
print(f"{indent}| > Pad: {self._pad}")
|
||||
print(f"{indent}| > EOS: {self._eos}")
|
||||
print(f"{indent}| > BOS: {self._bos}")
|
||||
print(f"{indent}| > Blank: {self._blank}")
|
||||
print(f"{indent}| > Vocab: {self.vocab}")
|
||||
print(f"{indent}| > Num chars: {self.num_chars}")
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument
|
||||
"""Init your character class from a config.
|
||||
|
||||
Implement this method for your subclass.
|
||||
"""
|
||||
# use character set from config
|
||||
if config.characters is not None:
|
||||
return BaseCharacters(**config.characters), config
|
||||
# return default character set
|
||||
characters = BaseCharacters()
|
||||
new_config = replace(config, characters=characters.to_config())
|
||||
return characters, new_config
|
||||
|
||||
def to_config(self) -> "CharactersConfig":
|
||||
return CharactersConfig(
|
||||
characters=self._characters,
|
||||
punctuations=self._punctuations,
|
||||
pad=self._pad,
|
||||
eos=self._eos,
|
||||
bos=self._bos,
|
||||
blank=self._blank,
|
||||
is_unique=self.is_unique,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
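# Hedged ordering sketch for _create_vocab above (toy inputs, is_sorted=True):
bc = BaseCharacters(characters="cba", punctuations="!? ", pad="<PAD>",
                    eos="<EOS>", bos="<BOS>", blank="<BLNK>")
assert bc.vocab == ["<PAD>", "<EOS>", "<BOS>", "<BLNK>", "a", "b", "c", "!", "?", " "]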
|
||||
|
||||
class IPAPhonemes(BaseCharacters):
|
||||
"""🐸IPAPhonemes class to manage `TTS.tts` model vocabulary
|
||||
|
||||
Intended to be used with models using IPAPhonemes as input.
|
||||
It uses system defaults for the undefined class arguments.
|
||||
|
||||
Args:
|
||||
characters (str):
|
||||
Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`.
|
||||
|
||||
punctuations (str):
|
||||
Characters to be treated as punctuation. Defaults to `_punctuations`.
|
||||
|
||||
pad (str):
|
||||
Special padding character that would be ignored by the model. Defaults to `_pad`.
|
||||
|
||||
eos (str):
|
||||
End of the sentence character. Defaults to `_eos`.
|
||||
|
||||
bos (str):
|
||||
Beginning of the sentence character. Defaults to `_bos`.
|
||||
|
||||
blank (str):
|
||||
Optional character used between characters by some models for better prosody. Defaults to `_blank`.
|
||||
|
||||
is_unique (bool):
|
||||
Remove duplicates from the provided characters. Defaults to True.
|
||||
|
||||
is_sorted (bool):
|
||||
Sort the characters in alphabetical order. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characters: str = _phonemes,
|
||||
punctuations: str = _punctuations,
|
||||
pad: str = _pad,
|
||||
eos: str = _eos,
|
||||
bos: str = _bos,
|
||||
blank: str = _blank,
|
||||
is_unique: bool = False,
|
||||
is_sorted: bool = True,
|
||||
) -> None:
|
||||
super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "Coqpit"):
|
||||
"""Init a IPAPhonemes object from a model config
|
||||
|
||||
If characters are not defined in the config, it will be set to the default characters and the config
|
||||
will be updated.
|
||||
"""
|
||||
# band-aid for compatibility with old models
|
||||
if "characters" in config and config.characters is not None:
|
||||
if "phonemes" in config.characters and config.characters.phonemes is not None:
|
||||
config.characters["characters"] = config.characters["phonemes"]
|
||||
return (
|
||||
IPAPhonemes(
|
||||
characters=config.characters["characters"],
|
||||
punctuations=config.characters["punctuations"],
|
||||
pad=config.characters["pad"],
|
||||
eos=config.characters["eos"],
|
||||
bos=config.characters["bos"],
|
||||
blank=config.characters["blank"],
|
||||
is_unique=config.characters["is_unique"],
|
||||
is_sorted=config.characters["is_sorted"],
|
||||
),
|
||||
config,
|
||||
)
|
||||
# use character set from config
|
||||
if config.characters is not None:
|
||||
return IPAPhonemes(**config.characters), config
|
||||
# return default character set
|
||||
characters = IPAPhonemes()
|
||||
new_config = replace(config, characters=characters.to_config())
|
||||
return characters, new_config
|
||||
|
||||
|
||||
class Graphemes(BaseCharacters):
|
||||
"""🐸Graphemes class to manage `TTS.tts` model vocabulary
|
||||
|
||||
Intended to be used with models using graphemes as input.
|
||||
It uses system defaults for the undefined class arguments.
|
||||
|
||||
Args:
|
||||
characters (str):
|
||||
Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`.
|
||||
|
||||
punctuations (str):
|
||||
Characters to be treated as punctuation. Defaults to `_punctuations`.
|
||||
|
||||
pad (str):
|
||||
Special padding character that would be ignored by the model. Defaults to `_pad`.
|
||||
|
||||
eos (str):
|
||||
End of the sentence character. Defaults to `_eos`.
|
||||
|
||||
bos (str):
|
||||
Beginning of the sentence character. Defaults to `_bos`.
|
||||
|
||||
is_unique (bool):
|
||||
Remove duplicates from the provided characters. Defaults to True.
|
||||
|
||||
is_sorted (bool):
|
||||
Sort the characters in alphabetical order. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characters: str = _characters,
|
||||
punctuations: str = _punctuations,
|
||||
pad: str = _pad,
|
||||
eos: str = _eos,
|
||||
bos: str = _bos,
|
||||
blank: str = _blank,
|
||||
is_unique: bool = False,
|
||||
is_sorted: bool = True,
|
||||
) -> None:
|
||||
super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "Coqpit"):
|
||||
"""Init a Graphemes object from a model config
|
||||
|
||||
If characters are not defined in the config, it will be set to the default characters and the config
|
||||
will be updated.
|
||||
"""
|
||||
if config.characters is not None:
|
||||
# band-aid for compatibility with old models
|
||||
if "phonemes" in config.characters:
|
||||
return (
|
||||
Graphemes(
|
||||
characters=config.characters["characters"],
|
||||
punctuations=config.characters["punctuations"],
|
||||
pad=config.characters["pad"],
|
||||
eos=config.characters["eos"],
|
||||
bos=config.characters["bos"],
|
||||
blank=config.characters["blank"],
|
||||
is_unique=config.characters["is_unique"],
|
||||
is_sorted=config.characters["is_sorted"],
|
||||
),
|
||||
config,
|
||||
)
|
||||
return Graphemes(**config.characters), config
|
||||
characters = Graphemes()
|
||||
new_config = replace(config, characters=characters.to_config())
|
||||
return characters, new_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gr = Graphemes()
|
||||
ph = IPAPhonemes()
|
||||
gr.print_log()
|
||||
ph.print_log()
|
|
@ -19,7 +19,7 @@ def _chinese_pinyin_to_phoneme(pinyin: str) -> str:
|
|||
return phoneme + tone
|
||||
|
||||
|
||||
def chinese_text_to_phonemes(text: str) -> str:
|
||||
def chinese_text_to_phonemes(text: str, seperator: str = "|") -> str:
|
||||
tokenized_text = jieba.cut(text, HMM=False)
|
||||
tokenized_text = " ".join(tokenized_text)
|
||||
pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text)
|
||||
|
@ -34,4 +34,4 @@ def chinese_text_to_phonemes(text: str) -> str:
|
|||
else:  # is punctuation or other
|
||||
results += list(token)
|
||||
|
||||
return "|".join(results)
|
||||
return seperator.join(results)
|
||||
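# e.g. chinese_text_to_phonemes("你好", seperator="|") returns a "|"-separated
# pinyin-phoneme string (exact output depends on the jieba/pypinyin versions).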
|
|
|
@ -1,12 +1,16 @@
|
|||
"""Set of default text cleaners"""
|
||||
# TODO: pick the cleaner for languages dynamically
|
||||
|
||||
import re
|
||||
|
||||
from anyascii import anyascii
|
||||
|
||||
from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text
|
||||
|
||||
from .abbreviations import abbreviations_en, abbreviations_fr
|
||||
from .number_norm import normalize_numbers
|
||||
from .time import expand_time_english
|
||||
from .english.abbreviations import abbreviations_en
|
||||
from .english.number_norm import normalize_numbers as en_normalize_numbers
|
||||
from .english.time_norm import expand_time_english
|
||||
from .french.abbreviations import abbreviations_fr
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
@ -22,10 +26,6 @@ def expand_abbreviations(text, lang="en"):
|
|||
return text
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
@ -92,7 +92,17 @@ def english_cleaners(text):
|
|||
# text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_time_english(text)
|
||||
text = expand_numbers(text)
|
||||
text = en_normalize_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = replace_symbols(text)
|
||||
text = remove_aux_symbols(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def phoneme_cleaners(text):
|
||||
"""Pipeline for phonemes mode, including number and abbreviation expansion."""
|
||||
text = en_normalize_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = replace_symbols(text)
|
||||
text = remove_aux_symbols(text)
|
||||
|
@ -126,17 +136,6 @@ def chinese_mandarin_cleaners(text: str) -> str:
|
|||
return text
|
||||
|
||||
|
||||
def phoneme_cleaners(text):
|
||||
"""Pipeline for phonemes mode, including number and abbreviation expansion."""
|
||||
text = expand_numbers(text)
|
||||
# text = convert_to_ascii(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = replace_symbols(text)
|
||||
text = remove_aux_symbols(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def multilingual_cleaners(text):
|
||||
"""Pipeline for multilingual text"""
|
||||
text = lowercase(text)
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
import re
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations in english:
|
||||
abbreviations_en = [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
]
|
||||
]
|
|
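# Applying these pairs is a simple loop (a sketch; the exact substitution,
# e.g. trailing-space handling, lives in the cleaners' expand_abbreviations):
# for regex, replacement in abbreviations_en:
#     text = regex.sub(replacement, text)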
@ -1,30 +1,5 @@
|
|||
import re
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations in english:
|
||||
abbreviations_en = [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
]
|
||||
]
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations in french:
|
||||
abbreviations_fr = [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
|
@ -0,0 +1,57 @@
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer

PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, JA_JP_Phonemizer)}


ESPEAK_LANGS = list(ESpeak.supported_languages().keys())
GRUUT_LANGS = list(Gruut.supported_languages())


# Dict setting default phonemizers for each language
DEF_LANG_TO_PHONEMIZER = {
    "ja-jp": JA_JP_Phonemizer.name(),
    "zh-cn": ZH_CN_Phonemizer.name(),
}


# Add Gruut languages
_ = [Gruut.name()] * len(GRUUT_LANGS)
_new_dict = dict(list(zip(GRUUT_LANGS, _)))
DEF_LANG_TO_PHONEMIZER.update(_new_dict)


# Add ESpeak languages and override any existing ones
_ = [ESpeak.name()] * len(ESPEAK_LANGS)
_new_dict = dict(list(zip(list(ESPEAK_LANGS), _)))
DEF_LANG_TO_PHONEMIZER.update(_new_dict)

DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"]


def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer:
    """Instantiate a phonemizer by name

    Args:
        name (str):
            Name of the phonemizer that should match `phonemizer.name()`.

        kwargs (dict):
            Extra keyword arguments that should be passed to the phonemizer.
    """
    if name == "espeak":
        return ESpeak(**kwargs)
    if name == "gruut":
        return Gruut(**kwargs)
    if name == "zh_cn_phonemizer":
        return ZH_CN_Phonemizer(**kwargs)
    if name == "ja_jp_phonemizer":
        return JA_JP_Phonemizer(**kwargs)
    raise ValueError(f"Phonemizer {name} not found")


if __name__ == "__main__":
    print(DEF_LANG_TO_PHONEMIZER)
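A minimal usage sketch of the registry above: resolve the default backend for a language code, then instantiate it by name (assumes an `espeak`/`espeak-ng` binary is installed, in which case `DEF_LANG_TO_PHONEMIZER["en"]` resolves to the espeak backend):

from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name

name = DEF_LANG_TO_PHONEMIZER["en"]  # "espeak" when an espeak binary is present
phonemizer = get_phonemizer_by_name(name, language="en-us")
print(phonemizer.phonemize("hello world", separator="|"))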
@ -0,0 +1,141 @@
import abc
from typing import List, Tuple

from TTS.tts.utils.text.punctuation import Punctuation


class BasePhonemizer(abc.ABC):
    """Base phonemizer class

    Phonemization follows the following steps:
        1. Preprocessing:
            - remove empty lines
            - remove punctuation
            - keep track of punctuation marks

        2. Phonemization:
            - convert text to phonemes

        3. Postprocessing:
            - join phonemes
            - restore punctuation marks

    Args:
        language (str):
            Language used by the phonemizer.

        punctuations (List[str]):
            List of punctuation marks to be preserved.

        keep_puncs (bool):
            Whether to preserve punctuation marks or not.
    """

    def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):

        # ensure the backend is installed on the system
        if not self.is_available():
            raise RuntimeError("{} not installed on your system".format(self.name()))  # pragma: nocover

        # ensure the backend supports the requested language
        self._language = self._init_language(language)

        # setup punctuation processing
        self._keep_puncs = keep_puncs
        self._punctuator = Punctuation(punctuations)

    def _init_language(self, language):
        """Language initialization

        This method may be overloaded in child classes (see Segments backend)
        """
        if not self.is_supported_language(language):
            raise RuntimeError(f'language "{language}" is not supported by the {self.name()} backend')
        return language

    @property
    def language(self):
        """The language code configured to be used for phonemization"""
        return self._language

    @staticmethod
    @abc.abstractmethod
    def name():
        """The name of the backend"""
        ...

    @classmethod
    @abc.abstractmethod
    def is_available(cls):
        """Returns True if the backend is installed, False otherwise"""
        ...

    @classmethod
    @abc.abstractmethod
    def version(cls):
        """Return the backend version as a tuple (major, minor, patch)"""
        ...

    @staticmethod
    @abc.abstractmethod
    def supported_languages():
        """Return a dict of language codes -> name supported by the backend"""
        ...

    def is_supported_language(self, language):
        """Returns True if `language` is supported by the backend"""
        return language in self.supported_languages()

    @abc.abstractmethod
    def _phonemize(self, text, separator):
        """The main phonemization method"""

    def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
        """Preprocess the text before phonemization

        1. remove spaces
        2. remove punctuation

        Override this if you need a different behaviour
        """
        text = text.strip()
        if self._keep_puncs:
            # a tuple (text, punctuation marks)
            return self._punctuator.strip_to_restore(text)
        return [self._punctuator.strip(text)], []

    def _phonemize_postprocess(self, phonemized, punctuations) -> str:
        """Postprocess the raw phonemized output

        Override this if you need a different behaviour
        """
        if self._keep_puncs:
            return self._punctuator.restore(phonemized, punctuations)[0]
        return phonemized[0]

    def phonemize(self, text: str, separator="|") -> str:
        """Returns the `text` phonemized for the given language

        Args:
            text (str):
                Text to be phonemized.

            separator (str):
                String separator used between phonemes. Defaults to '|'.

        Returns:
            (str): Phonemized text
        """
        text, punctuations = self._phonemize_preprocess(text)
        phonemized = []
        for t in text:
            p = self._phonemize(t, separator)
            phonemized.append(p)
        phonemized = self._phonemize_postprocess(phonemized, punctuations)
        return phonemized

    def print_logs(self, level: int = 0):
        indent = "\t" * level
        print(f"{indent}| > phoneme language: {self.language}")
        print(f"{indent}| > phoneme backend: {self.name()}")
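Because `phonemize()` drives the preprocess -> `_phonemize` -> postprocess pipeline itself, a concrete backend only has to fill in the abstract hooks. A toy subclass as an illustration (not a real backend; it "phonemizes" by upper-casing):

from TTS.tts.utils.text.phonemizers.base import BasePhonemizer


class UpperCasePhonemizer(BasePhonemizer):
    """Toy backend that upper-cases characters instead of phonemizing."""

    @staticmethod
    def name():
        return "uppercase"

    @classmethod
    def is_available(cls):
        return True

    @classmethod
    def version(cls):
        return (0, 0, 1)

    @staticmethod
    def supported_languages():
        return {"en-us": "English (US)"}

    def _phonemize(self, text, separator):
        return separator.join(text.upper())


# UpperCasePhonemizer("en-us").phonemize("hi", separator="|") -> 'H|I'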
@ -0,0 +1,225 @@
import logging
import subprocess
from typing import Dict, List

from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation


def is_tool(name):
    from shutil import which

    return which(name) is not None


# priority: espeak-ng > espeak
if is_tool("espeak-ng"):
    _DEF_ESPEAK_LIB = "espeak-ng"
elif is_tool("espeak"):
    _DEF_ESPEAK_LIB = "espeak"
else:
    _DEF_ESPEAK_LIB = None


def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]:
    """Run espeak with the given arguments."""
    cmd = [
        espeak_lib,
        "-q",
        "-b",
        "1",  # UTF8 text encoding
    ]
    cmd.extend(args)
    logging.debug("espeakng: executing %s", repr(cmd))

    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    ) as p:
        res = iter(p.stdout.readline, b"")
        if not sync:
            p.stdout.close()
            if p.stderr:
                p.stderr.close()
            if p.stdin:
                p.stdin.close()
            return res
        res2 = []
        for line in res:
            res2.append(line)
        p.stdout.close()
        if p.stderr:
            p.stderr.close()
        if p.stdin:
            p.stdin.close()
        p.wait()
    return res2


class ESpeak(BasePhonemizer):
    """ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line to perform G2P

    Args:
        language (str):
            Valid language code for the used backend.

        backend (str):
            Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically
            preferring `espeak-ng` over `espeak`. Defaults to None.

        punctuations (str):
            Characters to be treated as punctuation. Defaults to Punctuation.default_puncs().

        keep_puncs (bool):
            If True, keep the punctuations after phonemization. Defaults to True.

    Example:

        >>> from TTS.tts.utils.text.phonemizers import ESpeak
        >>> phonemizer = ESpeak("tr")
        >>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|")
        'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.'

    """

    _ESPEAK_LIB = _DEF_ESPEAK_LIB

    def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True):
        if self._ESPEAK_LIB is None:
            raise Exception(" [!] No espeak backend found. Install espeak-ng or espeak on your system.")
        self.backend = self._ESPEAK_LIB

        # band-aid for backwards compatibility
        if language == "en":
            language = "en-us"

        super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
        if backend is not None:
            self.backend = backend

    @property
    def backend(self):
        return self._ESPEAK_LIB

    @backend.setter
    def backend(self, backend):
        if backend not in ["espeak", "espeak-ng"]:
            raise Exception("Unknown backend: %s" % backend)
        self._ESPEAK_LIB = backend
        # skip the first two characters of the returned text
        # "_ p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
        #  ^^
        self.num_skip_chars = 2
        if backend == "espeak-ng":
            # skip the first character of the returned text
            # "_p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
            #  ^
            self.num_skip_chars = 1

    def auto_set_espeak_lib(self) -> None:
        if is_tool("espeak-ng"):
            self._ESPEAK_LIB = "espeak-ng"
        elif is_tool("espeak"):
            self._ESPEAK_LIB = "espeak"
        else:
            raise Exception("Cannot set backend automatically. espeak-ng or espeak not found")

    @staticmethod
    def name():
        return "espeak"

    def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str:
        """Convert input text to phonemes.

        Args:
            text (str):
                Text to be converted to phonemes.

            tie (bool, optional) : When True use a '͡' character between
                consecutive characters of a single phoneme. Else separate phonemes
                with '_'. This option requires espeak>=1.49. Defaults to False.
        """
        # set arguments
        args = ["-v", f"{self._language}"]
        # espeak and espeak-ng parse `ipa` differently
        if tie:
            # use '͡' between phonemes
            if self.backend == "espeak":
                args.append("--ipa=1")
            else:
                args.append("--ipa=3")
        else:
            # split with '_'
            if self.backend == "espeak":
                args.append("--ipa=3")
            else:
                args.append("--ipa=1")
        if tie:
            args.append("--tie=%s" % tie)

        args.append('"' + text + '"')
        # compute phonemes
        phonemes = ""
        for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
            logging.debug("line: %s", repr(line))
            phonemes += line.decode("utf8").strip()[self.num_skip_chars :]  # skip initial redundant characters
        return phonemes.replace("_", separator)

    def _phonemize(self, text, separator=None):
        return self.phonemize_espeak(text, separator, tie=False)

    @staticmethod
    def supported_languages() -> Dict:
        """Get a dictionary of supported languages.

        Returns:
            Dict: Dictionary of language codes.
        """
        if _DEF_ESPEAK_LIB is None:
            return {}
        args = ["--voices"]
        langs = {}
        count = 0
        for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True):
            line = line.decode("utf8").strip()
            if count > 0:
                cols = line.split()
                lang_code = cols[1]
                lang_name = cols[3]
                langs[lang_code] = lang_name
            logging.debug("line: %s", repr(line))
            count += 1
        return langs

    def version(self) -> str:
        """Get the version of the used backend.

        Returns:
            str: Version of the used backend.
        """
        args = ["--version"]
        for line in _espeak_exe(self.backend, args, sync=True):
            version = line.decode("utf8").strip().split()[2]
            logging.debug("line: %s", repr(line))
            return version

    @classmethod
    def is_available(cls):
        """Return True if ESpeak is available else False"""
        return is_tool("espeak") or is_tool("espeak-ng")


if __name__ == "__main__":
    e = ESpeak(language="en-us")
    print(e.supported_languages())
    print(e.version())
    print(e.language)
    print(e.name())
    print(e.is_available())

    e = ESpeak(language="en-us", keep_puncs=False)
    print("`" + e.phonemize("hello how are you today?") + "`")

    e = ESpeak(language="en-us", keep_puncs=True)
    print("`" + e.phonemize("hello how are you today?") + "`")
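Two CLI details above are easy to miss: `--ipa=1` and `--ipa=3` swap meanings between `espeak` and `espeak-ng`, and the backend's `_`-joined output is remapped onto the caller's separator as the last step. A hedged sketch of that remapping step in isolation (a standalone restatement, not a call into the wrapper):

def apply_separator(espeak_output: str, separator: str = "|") -> str:
    # espeak emits phonemes joined by '_'; re-join with the requested separator
    return espeak_output.strip().replace("_", separator)


assert apply_separator("h_ə_l_oʊ") == "h|ə|l|oʊ"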
@ -0,0 +1,151 @@
import importlib
from typing import List

import gruut
from gruut_ipa import IPA

from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation

# Table for str.translate to fix gruut/TTS phoneme mismatch
GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")


class Gruut(BasePhonemizer):
    """Gruut wrapper for G2P

    Args:
        language (str):
            Valid language code for the used backend.

        punctuations (str):
            Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.

        keep_puncs (bool):
            If True, keep the punctuations after phonemization. Defaults to True.

        use_espeak_phonemes (bool):
            If True, use espeak lexicons instead of default Gruut lexicons. Defaults to False.

        keep_stress (bool):
            If True, keep the stress characters after phonemization. Defaults to False.

    Example:

        >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
        >>> phonemizer = Gruut('en-us')
        >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
        'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
    """

    def __init__(
        self,
        language: str,
        punctuations=Punctuation.default_puncs(),
        keep_puncs=True,
        use_espeak_phonemes=False,
        keep_stress=False,
    ):
        super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
        self.use_espeak_phonemes = use_espeak_phonemes
        self.keep_stress = keep_stress

    @staticmethod
    def name():
        return "gruut"

    def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str:  # pylint: disable=unused-argument
        """Convert input text to phonemes.

        Gruut phonemizes the given `str` by separating each phoneme character with `separator`, even for characters
        that constitute a single sound.

        It doesn't affect 🐸TTS since it individually converts each character to token IDs.

        Examples::
            "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`

        Args:
            text (str):
                Text to be converted to phonemes.

            tie (bool, optional) : When True use a '͡' character between
                consecutive characters of a single phoneme. Else separate phonemes
                with '_'. This option requires espeak>=1.49. Defaults to False.
        """
        ph_list = []
        for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
            for word in sentence:
                if word.is_break:
                    # Use actual character for break phoneme (e.g., comma)
                    if ph_list:
                        # Join with previous word
                        ph_list[-1].append(word.text)
                    else:
                        # First word is punctuation
                        ph_list.append([word.text])
                elif word.phonemes:
                    # Add phonemes for word
                    word_phonemes = []

                    for word_phoneme in word.phonemes:
                        if not self.keep_stress:
                            # Remove primary/secondary stress
                            word_phoneme = IPA.without_stress(word_phoneme)

                        word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)

                        if word_phoneme:
                            # Flatten phonemes
                            word_phonemes.extend(word_phoneme)

                    if word_phonemes:
                        ph_list.append(word_phonemes)

        ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
        ph = f"{separator} ".join(ph_words)
        return ph

    def _phonemize(self, text, separator):
        return self.phonemize_gruut(text, separator, tie=False)

    def is_supported_language(self, language):
        """Returns True if `language` is supported by the backend"""
        return gruut.is_language_supported(language)

    @staticmethod
    def supported_languages() -> List:
        """Get a list of supported languages.

        Returns:
            List: List of language codes.
        """
        return list(gruut.get_supported_languages())

    def version(self):
        """Get the version of the used backend.

        Returns:
            str: Version of the used backend.
        """
        return gruut.__version__

    @classmethod
    def is_available(cls):
        """Return True if gruut is available else False"""
        return importlib.util.find_spec("gruut") is not None


if __name__ == "__main__":
    e = Gruut(language="en-us")
    print(e.supported_languages())
    print(e.version())
    print(e.language)
    print(e.name())
    print(e.is_available())

    e = Gruut(language="en-us", keep_puncs=False)
    print("`" + e.phonemize("hello how are you today?") + "`")

    e = Gruut(language="en-us", keep_puncs=True)
    print("`" + e.phonemize("hello how, are you today?") + "`")
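The translate table exists because gruut emits the ASCII letter 'g' (U+0067) while the TTS IPA character set expects the IPA glyph 'ɡ' (U+0261). A quick check of the mismatch and the fix:

ascii_g = "g"  # U+0067, what gruut emits
ipa_g = "ɡ"    # U+0261, what the TTS IPA vocabulary contains
assert ascii_g != ipa_g
assert ascii_g.translate(str.maketrans("g", "ɡ")) == ipa_g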
@ -0,0 +1,72 @@
from typing import Dict

from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer

_DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】"

_TRANS_TABLE = {"、": ","}


def trans(text):
    for i, j in _TRANS_TABLE.items():
        text = text.replace(i, j)
    return text


class JA_JP_Phonemizer(BasePhonemizer):
    """🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer`

    TODO: someone with JA knowledge should check this implementation

    Example:

        >>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer
        >>> phonemizer = JA_JP_Phonemizer()
        >>> phonemizer.phonemize("どちらに行きますか?", separator="|")
        'd|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?'

    """

    language = "ja-jp"

    def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs):  # pylint: disable=unused-argument
        super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)

    @staticmethod
    def name():
        return "ja_jp_phonemizer"

    def _phonemize(self, text: str, separator: str = "|") -> str:
        ph = japanese_text_to_phonemes(text)
        # both conditions must hold; with `or` the check would always pass
        if separator is not None and separator != "":
            return separator.join(ph)
        return ph

    def phonemize(self, text: str, separator="|") -> str:
        """Custom phonemize for JP_JA

        Skip the pre- and post-processing steps used by the other phonemizers.
        """
        return self._phonemize(text, separator)

    @staticmethod
    def supported_languages() -> Dict:
        return {"ja-jp": "Japanese (Japan)"}

    def version(self) -> str:
        return "0.0.1"

    def is_available(self) -> bool:
        return True


# if __name__ == "__main__":
#     text = "これは、電話をかけるための私の日本語の例のテキストです。"
#     e = JA_JP_Phonemizer()
#     print(e.supported_languages())
#     print(e.version())
#     print(e.language)
#     print(e.name())
#     print(e.is_available())
#     print("`" + e.phonemize(text) + "`")
@ -0,0 +1,55 @@
from typing import Dict, List

from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name


class MultiPhonemizer:
    """🐸TTS multi-phonemizer that operates phonemizers for multiple languages

    Args:
        custom_lang_to_phonemizer (Dict):
            Custom phonemizer mapping if you want to change the defaults. In the format of
            `{"lang_code": "phonemizer_name"}`. When it is None, `DEF_LANG_TO_PHONEMIZER` is used. Defaults to `{}`.

    TODO: find a way to pass custom kwargs to the phonemizers
    """

    lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER
    language = "multi-lingual"

    def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None:  # pylint: disable=dangerous-default-value
        self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer)
        self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name)

    @staticmethod
    def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict:
        lang_to_phonemizer = {}
        for k, v in lang_to_phonemizer_name.items():
            phonemizer = get_phonemizer_by_name(v, language=k)
            lang_to_phonemizer[k] = phonemizer
        return lang_to_phonemizer

    @staticmethod
    def name():
        return "multi-phonemizer"

    def phonemize(self, text, language, separator="|"):
        return self.lang_to_phonemizer[language].phonemize(text, separator)

    def supported_languages(self) -> List:
        return list(self.lang_to_phonemizer_name.keys())


# if __name__ == "__main__":
#     texts = {
#         "tr": "Merhaba, bu Türkçe bir örnek!",
#         "en-us": "Hello, this is English example!",
#         "de": "Hallo, das ist ein Deutsches Beispiel!",
#         "zh-cn": "这是中国的例子",
#     }
#     phonemes = {}
#     ph = MultiPhonemizer()
#     for lang, text in texts.items():
#         phoneme = ph.phonemize(text, lang)
#         phonemes[lang] = phoneme
#     print(phonemes)
@ -0,0 +1,62 @@
from typing import Dict

from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer

_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】"


class ZH_CN_Phonemizer(BasePhonemizer):
    """🐸TTS Zh-Cn phonemizer using functions in `TTS.tts.utils.text.chinese_mandarin.phonemizer`

    Args:
        punctuations (str):
            Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`.

        keep_puncs (bool):
            If True, keep the punctuations after phonemization. Defaults to False.

    Example ::

        "这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| |,| |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |。`

    TODO: someone with Mandarin knowledge should check this implementation
    """

    language = "zh-cn"

    def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs):  # pylint: disable=unused-argument
        super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)

    @staticmethod
    def name():
        return "zh_cn_phonemizer"

    @staticmethod
    def phonemize_zh_cn(text: str, separator: str = "|") -> str:
        ph = chinese_text_to_phonemes(text, separator)
        return ph

    def _phonemize(self, text, separator):
        return self.phonemize_zh_cn(text, separator)

    @staticmethod
    def supported_languages() -> Dict:
        return {"zh-cn": "Chinese (China)"}

    def version(self) -> str:
        return "0.0.1"

    def is_available(self) -> bool:
        return True


# if __name__ == "__main__":
#     text = "这是,样本中文。"
#     e = ZH_CN_Phonemizer()
#     print(e.supported_languages())
#     print(e.version())
#     print(e.language)
#     print(e.name())
#     print(e.is_available())
#     print("`" + e.phonemize(text) + "`")
@ -0,0 +1,172 @@
import collections
import re
from enum import Enum

import six

_DEF_PUNCS = ';:,.!?¡¿—…"«»“”'

_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])


class PuncPosition(Enum):
    """Enum for the punctuation positions"""

    BEGIN = 0
    END = 1
    MIDDLE = 2
    ALONE = 3


class Punctuation:
    """Handle punctuations in text.

    Just strip punctuations from text or strip and restore them later.

    Args:
        puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.

    Example:
        >>> punc = Punctuation()
        >>> punc.strip("This is. example !")
        'This is example'

        >>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
        >>> ' '.join(text_striped)
        'This is example'

        >>> text_restored = punc.restore(text_striped, punc_map)
        >>> text_restored[0]
        'This is. example !'
    """

    def __init__(self, puncs: str = _DEF_PUNCS):
        self.puncs = puncs

    @staticmethod
    def default_puncs():
        """Return the default set of punctuations."""
        return _DEF_PUNCS

    @property
    def puncs(self):
        return self._puncs

    @puncs.setter
    def puncs(self, value):
        if not isinstance(value, six.string_types):
            raise ValueError("[!] Punctuations must be of type str.")
        self._puncs = "".join(list(dict.fromkeys(list(value))))  # remove duplicates without changing the order
        self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")

    def strip(self, text):
        """Remove all the punctuations by replacing with `space`.

        Args:
            text (str): The text to be processed.

        Example::

            "This is. example !" -> "This is example "
        """
        return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()

    def strip_to_restore(self, text):
        """Remove punctuations from text to restore them later.

        Args:
            text (str): The text to be processed.

        Examples ::

            "This is. example !" -> [["This is", "example"], [".", "!"]]

        """
        text, puncs = self._strip_to_restore(text)
        return text, puncs

    def _strip_to_restore(self, text):
        """Auxiliary method for Punctuation.preserve()"""
        matches = list(re.finditer(self.puncs_regular_exp, text))
        if not matches:
            return [text], []
        # the text is only punctuations
        if len(matches) == 1 and matches[0].group() == text:
            return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
        # build a punctuation map to be used later to restore punctuations
        puncs = []
        for match in matches:
            position = PuncPosition.MIDDLE
            if match == matches[0] and text.startswith(match.group()):
                position = PuncPosition.BEGIN
            elif match == matches[-1] and text.endswith(match.group()):
                position = PuncPosition.END
            puncs.append(_PUNC_IDX(match.group(), position))
        # convert str text to a List[str], each item is separated by a punctuation
        splitted_text = []
        for idx, punc in enumerate(puncs):
            split = text.split(punc.punc)
            prefix, suffix = split[0], punc.punc.join(split[1:])
            splitted_text.append(prefix)
            # if the text does not end with a punctuation, add it to the last item
            if idx == len(puncs) - 1 and len(suffix) > 0:
                splitted_text.append(suffix)
            text = suffix
        return splitted_text, puncs

    @classmethod
    def restore(cls, text, puncs):
        """Restore punctuation in a text.

        Args:
            text (str): The text to be processed.
            puncs (List[str]): The list of punctuations map to be used for restoring.

        Examples ::

            ['This is', 'example'], ['.', '!'] -> "This is. example!"

        """
        return cls._restore(text, puncs, 0)

    @classmethod
    def _restore(cls, text, puncs, num):  # pylint: disable=too-many-return-statements
        """Auxiliary method for Punctuation.restore()"""
        if not puncs:
            return text

        # nothing has been phonemized, return the puncs alone
        if not text:
            return ["".join(m.punc for m in puncs)]

        current = puncs[0]

        if current.position == PuncPosition.BEGIN:
            return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)

        if current.position == PuncPosition.END:
            return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)

        if current.position == PuncPosition.ALONE:
            return [current.punc] + cls._restore(text, puncs[1:], num + 1)

        # POSITION == MIDDLE
        if len(text) == 1:  # pragma: nocover
            # a corner case where the final part of an intermediate
            # mark (I) has not been phonemized
            return cls._restore([text[0] + current.punc], puncs[1:], num)

        return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)


# if __name__ == "__main__":
#     punc = Punctuation()
#     text = "This is. This is, example!"

#     print(punc.strip(text))

#     split_text, puncs = punc.strip_to_restore(text)
#     print(split_text, " ---- ", puncs)

#     restored_text = punc.restore(split_text, puncs)
#     print(restored_text)
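A hedged round trip showing how the split segments and the punctuation map recombine; this mirrors what `BasePhonemizer` does around `_phonemize` (upper-casing stands in for per-segment phonemization):

from TTS.tts.utils.text.punctuation import Punctuation

punc = Punctuation()
segments, punc_map = punc.strip_to_restore("Hi, there!")
# segments -> ['Hi', 'there']; punc_map records (', ', MIDDLE) and ('!', END)
processed = [s.upper() for s in segments]
print(punc.restore(processed, punc_map)[0])  # -> 'HI, THERE!'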
@ -1,75 +0,0 @@
# -*- coding: utf-8 -*-
"""
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
"""


def make_symbols(
    characters,
    phonemes=None,
    punctuations="!'(),-.:;? ",
    pad="_",
    eos="~",
    bos="^",
    unique=True,
):  # pylint: disable=redefined-outer-name
    """Function to create symbols and phonemes
    TODO: create phonemes_to_id and symbols_to_id dicts here."""
    _symbols = list(characters)
    _symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols
    _symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols
    _symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols
    _phonemes = None
    if phonemes is not None:
        _phonemes_sorted = (
            sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
        )  # this is to keep previous models compatible.
        # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
        # _arpabet = ["@" + s for s in _phonemes_sorted]
        # Export all symbols:
        _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
        # _symbols += _arpabet
    return _symbols, _phonemes


_pad = "_"
_eos = "~"
_bos = "^"
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? "
_punctuations = "!'(),-.:;? "

# Phonemes definition (All IPA characters)
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "ˈˌːˑ"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics

symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)

# Generate ALIEN language
# from random import shuffle
# shuffle(phonemes)


def parse_symbols():
    return {
        "pad": _pad,
        "eos": _eos,
        "bos": _bos,
        "characters": _characters,
        "punctuations": _punctuations,
        "phonemes": _phonemes,
    }


if __name__ == "__main__":
    print(" > TTS symbols {}".format(len(symbols)))
    print(symbols)
    print(" > TTS phonemes {}".format(len(phonemes)))
    print("".join(sorted(phonemes)))
@ -0,0 +1,205 @@
from typing import Callable, Dict, List, Union

from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.utils.generic_utils import get_import_path, import_class


class TTSTokenizer:
    """🐸TTS tokenizer to convert input characters to token IDs and back.

    Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later.

    Args:
        use_phonemes (bool):
            Whether to use phonemes instead of characters. Defaults to False.

        characters (Characters):
            A Characters object to use for character-to-ID and ID-to-character mappings.

        text_cleaner (callable):
            A function to pre-process the text before tokenization and phonemization. Defaults to None.

        phonemizer (Phonemizer):
            A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None.

    Example:

        >>> from TTS.tts.utils.text.tokenizer import TTSTokenizer
        >>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes())
        >>> text = "Hello world!"
        >>> ids = tokenizer.text_to_ids(text)
        >>> text_hat = tokenizer.ids_to_text(ids)
        >>> assert text == text_hat
    """

    def __init__(
        self,
        use_phonemes=False,
        text_cleaner: Callable = None,
        characters: "BaseCharacters" = None,
        phonemizer: Union["Phonemizer", Dict] = None,
        add_blank: bool = False,
        use_eos_bos=False,
    ):
        self.text_cleaner = text_cleaner
        self.use_phonemes = use_phonemes
        self.add_blank = add_blank
        self.use_eos_bos = use_eos_bos
        self.characters = characters
        self.not_found_characters = []
        self.phonemizer = phonemizer

    @property
    def characters(self):
        return self._characters

    @characters.setter
    def characters(self, new_characters):
        self._characters = new_characters
        self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None
        self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None

    def encode(self, text: str) -> List[int]:
        """Encodes a string of text as a sequence of IDs."""
        token_ids = []
        for char in text:
            try:
                idx = self.characters.char_to_id(char)
                token_ids.append(idx)
            except KeyError:
                # discard but store not found characters
                if char not in self.not_found_characters:
                    self.not_found_characters.append(char)
                    print(text)
                    print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """Decodes a sequence of IDs to a string of text."""
        text = ""
        for token_id in token_ids:
            text += self.characters.id_to_char(token_id)
        return text

    def text_to_ids(self, text: str, language: str = None) -> List[int]:  # pylint: disable=unused-argument
        """Converts a string of text to a sequence of token IDs.

        Args:
            text(str):
                The text to convert to token IDs.

            language(str):
                The language code of the text. Defaults to None.

        TODO:
            - Add support for language-specific processing.

        1. Text normalization
        2. Phonemization (if use_phonemes is True)
        3. Add blank char between characters
        4. Add BOS and EOS characters
        5. Text to token IDs
        """
        # TODO: text cleaner should pick the right routine based on the language
        if self.text_cleaner is not None:
            text = self.text_cleaner(text)
        if self.use_phonemes:
            text = self.phonemizer.phonemize(text, separator="")
        if self.add_blank:
            text = self.intersperse_blank_char(text, True)
        if self.use_eos_bos:
            text = self.pad_with_bos_eos(text)
        return self.encode(text)

    def ids_to_text(self, id_sequence: List[int]) -> str:
        """Converts a sequence of token IDs to a string of text."""
        return self.decode(id_sequence)

    def pad_with_bos_eos(self, char_sequence: List[str]):
        """Pads a sequence with the special BOS and EOS characters."""
        return [self.characters.bos] + list(char_sequence) + [self.characters.eos]

    def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
        """Intersperses the blank character between characters in a sequence.

        Use the ```blank``` character if defined else use the ```pad``` character.
        """
        char_to_use = self.characters.blank if use_blank_char else self.characters.pad
        result = [char_to_use] * (len(char_sequence) * 2 + 1)
        result[1::2] = char_sequence
        return result

    def print_logs(self, level: int = 0):
        indent = "\t" * level
        print(f"{indent}| > add_blank: {self.add_blank}")
        print(f"{indent}| > use_eos_bos: {self.use_eos_bos}")
        print(f"{indent}| > use_phonemes: {self.use_phonemes}")
        if self.use_phonemes:
            print(f"{indent}| > phonemizer:")
            self.phonemizer.print_logs(level + 1)
        if len(self.not_found_characters) > 0:
            print(f"{indent}| > {len(self.not_found_characters)} not found characters:")
            for char in self.not_found_characters:
                print(f"{indent}| > {char}")

    @staticmethod
    def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):
        """Init Tokenizer object from config

        Args:
            config (Coqpit): Coqpit model config.
            characters (BaseCharacters): Defines the model character set. If not set, use the default options based on
                the config values. Defaults to None.
        """
        # init cleaners
        text_cleaner = None
        if isinstance(config.text_cleaner, (str, list)):
            text_cleaner = getattr(cleaners, config.text_cleaner)

        # init characters
        if characters is None:
            # set characters based on defined characters class
            if config.characters and config.characters.characters_class:
                CharactersClass = import_class(config.characters.characters_class)
                characters, new_config = CharactersClass.init_from_config(config)
            # set characters based on config
            else:
                if config.use_phonemes:
                    # init phoneme set
                    characters, new_config = IPAPhonemes().init_from_config(config)
                else:
                    # init character set
                    characters, new_config = Graphemes().init_from_config(config)
        else:
            characters, new_config = characters.init_from_config(config)

        # set characters class
        new_config.characters.characters_class = get_import_path(characters)

        # init phonemizer
        phonemizer = None
        if config.use_phonemes:
            phonemizer_kwargs = {"language": config.phoneme_language}

            if "phonemizer" in config and config.phonemizer:
                phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
            else:
                try:
                    phonemizer = get_phonemizer_by_name(
                        DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
                    )
                except KeyError as e:
                    raise ValueError(
                        f"""No phonemizer found for language {config.phoneme_language}.
                        You may need to install a third party library for this language."""
                    ) from e

        return (
            TTSTokenizer(
                config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars
            ),
            new_config,
        )
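The five numbered steps in `text_to_ids` compose end to end. A hedged sketch wiring a grapheme tokenizer by hand (assumes the default `Graphemes` vocabulary covers the input characters; any `str -> str` callable can serve as the cleaner):

from TTS.tts.utils.text.characters import Graphemes
from TTS.tts.utils.text.tokenizer import TTSTokenizer

tokenizer = TTSTokenizer(
    use_phonemes=False,
    text_cleaner=str.lower,  # stand-in cleaner
    characters=Graphemes(),
)
ids = tokenizer.text_to_ids("Hello world")
assert tokenizer.ids_to_text(ids) == "hello world"  # round trip returns the cleaned text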
@ -4,8 +4,6 @@ import matplotlib.pyplot as plt
import numpy as np
import torch

from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme

matplotlib.use("Agg")
@ -89,12 +87,46 @@ def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False)
    return fig


def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
    """Plot pitch curves on top of the input characters.

    Args:
        pitch (np.array): Pitch values.
        chars (str): Characters to place on the x-axis.

    Shapes:
        pitch: :math:`(T,)`
    """
    old_fig_size = plt.rcParams["figure.figsize"]
    if fig_size is not None:
        plt.rcParams["figure.figsize"] = fig_size

    fig, ax = plt.subplots()

    x = np.array(range(len(chars)))
    my_xticks = chars
    plt.xticks(x, my_xticks)

    ax.set_xlabel("characters")
    ax.set_ylabel("freq")

    ax2 = ax.twinx()
    ax2.plot(pitch, linewidth=5.0, color="red")
    ax2.set_ylabel("F0")

    plt.rcParams["figure.figsize"] = old_fig_size
    if not output_fig:
        plt.close()
    return fig


def visualize(
    alignment,
    postnet_output,
    text,
    hop_length,
    CONFIG,
    tokenizer,
    stop_tokens=None,
    decoder_output=None,
    output_path=None,
@ -117,14 +149,8 @@ def visualize(
    plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
    # compute phoneme representation and back
    if CONFIG.use_phonemes:
        seq = phoneme_to_sequence(
            text,
            [CONFIG.text_cleaner],
            CONFIG.phoneme_language,
            CONFIG.enable_eos_bos_chars,
            tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
        )
        text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None)
        seq = tokenizer.text_to_ids(text)
        text = tokenizer.ids_to_text(seq)
        print(text)
    plt.yticks(range(len(text)), list(text))
    plt.colorbar()
@ -142,10 +142,10 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        )
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))

        if self.power is not None:
            S = S ** self.power
            S = S**self.power

        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
@ -239,6 +239,12 @@ class AudioProcessor(object):
        mel_fmax (int, optional):
            maximum filter frequency for computing melspectrograms. Defaults to None.

        pitch_fmin (int, optional):
            minimum filter frequency for computing pitch. Defaults to None.

        pitch_fmax (int, optional):
            maximum filter frequency for computing pitch. Defaults to None.

        spec_gain (int, optional):
            gain applied when converting amplitude to DB. Defaults to 20.
@ -300,6 +306,8 @@ class AudioProcessor(object):
        max_norm=None,
        mel_fmin=None,
        mel_fmax=None,
        pitch_fmax=None,
        pitch_fmin=None,
        spec_gain=20,
        stft_pad_mode="reflect",
        clip_norm=True,
@ -333,6 +341,8 @@ class AudioProcessor(object):
        self.symmetric_norm = symmetric_norm
        self.mel_fmin = mel_fmin or 0
        self.mel_fmax = mel_fmax
        self.pitch_fmin = pitch_fmin
        self.pitch_fmax = pitch_fmax
        self.spec_gain = float(spec_gain)
        self.stft_pad_mode = stft_pad_mode
        self.max_norm = 1.0 if max_norm is None else float(max_norm)
@ -379,6 +389,12 @@ class AudioProcessor(object):
        self.clip_norm = None
        self.symmetric_norm = None

    @staticmethod
    def init_from_config(config: "Coqpit", verbose=True):
        if "audio" in config:
            return AudioProcessor(verbose=verbose, **config.audio)
        return AudioProcessor(verbose=verbose, **config)

    ### setting up the parameters ###
    def _build_mel_basis(
        self,
@ -634,8 +650,8 @@ class AudioProcessor(object):
        S = self._db_to_amp(S)
        # Reconstruct phase
        if self.preemphasis != 0:
            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
        return self._griffin_lim(S ** self.power)
            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
        return self._griffin_lim(S**self.power)

    def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
        """Convert a melspectrogram to a waveform using Griffin-Lim vocoder."""
@ -643,8 +659,8 @@ class AudioProcessor(object):
        S = self._db_to_amp(D)
        S = self._mel_to_linear(S)  # Convert back to linear
        if self.preemphasis != 0:
            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
        return self._griffin_lim(S ** self.power)
            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
        return self._griffin_lim(S**self.power)

    def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
        """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@ -720,11 +736,12 @@ class AudioProcessor(object):
        >>> WAV_FILE = filename = librosa.util.example_audio_file()
        >>> from TTS.config import BaseAudioConfig
        >>> from TTS.utils.audio import AudioProcessor
        >>> conf = BaseAudioConfig(mel_fmax=8000)
        >>> conf = BaseAudioConfig(pitch_fmax=8000)
        >>> ap = AudioProcessor(**conf)
        >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
        >>> pitch = ap.compute_f0(wav)
        """
        assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`."
        # align F0 length to the spectrogram length
        if len(x) % self.hop_length == 0:
            x = np.pad(x, (0, self.hop_length // 2), mode="reflect")
@ -732,7 +749,7 @@ class AudioProcessor(object):
        f0, t = pw.dio(
            x.astype(np.double),
            fs=self.sample_rate,
            f0_ceil=self.mel_fmax,
            f0_ceil=self.pitch_fmax,
            frame_period=1000 * self.hop_length / self.sample_rate,
        )
        f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
@ -781,7 +798,7 @@ class AudioProcessor(object):
    @staticmethod
    def _rms_norm(wav, db_level=-27):
        r = 10 ** (db_level / 20)
        a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
        a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
        return wav * a

    def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
@ -853,7 +870,7 @@ class AudioProcessor(object):

    @staticmethod
    def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
        mu = 2 ** qc - 1
        mu = 2**qc - 1
        # wav_abs = np.minimum(np.abs(wav), 1.0)
        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
        # Quantize signal to the specified number of levels.
@ -865,13 +882,13 @@ class AudioProcessor(object):
    @staticmethod
    def mulaw_decode(wav, qc):
        """Recovers waveform from quantized values."""
        mu = 2 ** qc - 1
        mu = 2**qc - 1
        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
        return x

    @staticmethod
    def encode_16bits(x):
        return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
        return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)

    @staticmethod
    def quantize(x: np.ndarray, bits: int) -> np.ndarray:
@ -884,12 +901,12 @@ class AudioProcessor(object):
        Returns:
            np.ndarray: Quantized waveform.
        """
        return (x + 1.0) * (2 ** bits - 1) / 2
        return (x + 1.0) * (2**bits - 1) / 2

    @staticmethod
    def dequantize(x, bits):
        """Dequantize a waveform from the given number of bits."""
        return 2 * x / (2 ** bits - 1) - 1
        return 2 * x / (2**bits - 1) - 1


def _log(x, base):
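`mulaw_encode`/`mulaw_decode` above are the standard mu-law companding pair: sign(x) * ln(1 + mu*|x|) / ln(1 + mu) and its exact inverse. A hedged round-trip check in plain NumPy (a standalone restatement of the same formulas, not a call into `AudioProcessor`):

import numpy as np


def mulaw_encode(wav, qc):
    mu = 2**qc - 1
    return np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)


def mulaw_decode(wav, qc):
    mu = 2**qc - 1
    return np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)


x = np.linspace(-1, 1, 11)
# decode(encode(x)) recovers x up to floating-point error
assert np.allclose(mulaw_decode(mulaw_encode(x, qc=8), qc=8), x)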
@ -128,7 +128,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") ->

    while True:
        # Read by chunk to avoid filling memory
        chunk = file_obj.read(1024 ** 2)
        chunk = file_obj.read(1024**2)
        if not chunk:
            break
        hash_func.update(chunk)
Some files were not shown because too many files have changed in this diff.