mirror of https://github.com/coqui-ai/TTS.git, commit c63bb481e9
|
@ -37,8 +37,8 @@ In the worse case provide steps to reproduce the behaviour.
|
||||||
You can either run `TTS/bin/collect_env_info.py`
|
You can either run `TTS/bin/collect_env_info.py`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_details.py
|
wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_info.py
|
||||||
python collect_env_details.py
|
python collect_env_info.py
|
||||||
```
|
```
|
||||||
|
|
||||||
or fill in the fields below manually.
|
or fill in the fields below manually.
|
||||||
|
|
|
@ -22,25 +22,22 @@ jobs:
|
||||||
experimental: [false]
|
experimental: [false]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- uses: actions/cache@v1
|
|
||||||
with:
|
|
||||||
path: ~/.cache/pip
|
|
||||||
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: coqui-ai/setup-python@pip-cache-key-py-ver
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
architecture: x64
|
architecture: x64
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: 'requirements*'
|
||||||
- name: check OS
|
- name: check OS
|
||||||
run: cat /etc/os-release
|
run: cat /etc/os-release
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt-get update
|
||||||
sudo apt install -y git make
|
sudo apt-get install -y git make gcc
|
||||||
sudo apt install -y python3-wheel gcc
|
|
||||||
make system-deps
|
make system-deps
|
||||||
- name: Upgrade pip
|
- name: Install/upgrade Python setup deps
|
||||||
run: python3 -m pip install --upgrade pip
|
run: python3 -m pip install --upgrade pip setuptools wheel
|
||||||
- name: Install TTS
|
- name: Install TTS
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install .[all]
|
python3 -m pip install .[all]
|
||||||
|
|
|
@ -7,7 +7,7 @@ defaults:
|
||||||
shell:
|
shell:
|
||||||
bash
|
bash
|
||||||
jobs:
|
jobs:
|
||||||
build-package:
|
build-sdist:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
@ -23,10 +23,63 @@ jobs:
|
||||||
with:
|
with:
|
||||||
python-version: 3.8
|
python-version: 3.8
|
||||||
- run: |
|
- run: |
|
||||||
python -m pip install -U pip setuptools twine toml
|
python -m pip install -U pip setuptools wheel build
|
||||||
python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["build-system"]["requires"]))' | pip install -r /dev/stdin
|
|
||||||
- run: |
|
- run: |
|
||||||
python setup.py sdist
|
python -m build
|
||||||
|
- run: |
|
||||||
|
pip install dist/*.tar.gz
|
||||||
|
- uses: actions/upload-artifact@v2
|
||||||
|
with:
|
||||||
|
name: sdist
|
||||||
|
path: dist/*.tar.gz
|
||||||
|
build-wheels:
|
||||||
|
runs-on: ubuntu-20.04
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.6", "3.7", "3.8", "3.9"]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- run: |
|
||||||
|
python -m pip install -U pip setuptools wheel build
|
||||||
|
- run: |
|
||||||
|
python -m build
|
||||||
|
- run: |
|
||||||
|
python -m pip install dist/*.whl
|
||||||
|
- uses: actions/upload-artifact@v2
|
||||||
|
with:
|
||||||
|
name: wheel-${{ matrix.python-version }}
|
||||||
|
path: dist/*.whl
|
||||||
|
publish-artifacts:
|
||||||
|
runs-on: ubuntu-20.04
|
||||||
|
needs: [build-sdist, build-wheels]
|
||||||
|
steps:
|
||||||
|
- run: |
|
||||||
|
mkdir dist
|
||||||
|
- uses: actions/download-artifact@v2
|
||||||
|
with:
|
||||||
|
name: "sdist"
|
||||||
|
path: "dist/"
|
||||||
|
- uses: actions/download-artifact@v2
|
||||||
|
with:
|
||||||
|
name: "wheel-3.6"
|
||||||
|
path: "dist/"
|
||||||
|
- uses: actions/download-artifact@v2
|
||||||
|
with:
|
||||||
|
name: "wheel-3.7"
|
||||||
|
path: "dist/"
|
||||||
|
- uses: actions/download-artifact@v2
|
||||||
|
with:
|
||||||
|
name: "wheel-3.8"
|
||||||
|
path: "dist/"
|
||||||
|
- uses: actions/download-artifact@v2
|
||||||
|
with:
|
||||||
|
name: "wheel-3.9"
|
||||||
|
path: "dist/"
|
||||||
|
- run: |
|
||||||
|
ls -lh dist/
|
||||||
- name: Setup PyPI config
|
- name: Setup PyPI config
|
||||||
run: |
|
run: |
|
||||||
cat << EOF > ~/.pypirc
|
cat << EOF > ~/.pypirc
|
||||||
|
@ -34,5 +87,10 @@ jobs:
|
||||||
username=__token__
|
username=__token__
|
||||||
password=${{ secrets.PYPI_TOKEN }}
|
password=${{ secrets.PYPI_TOKEN }}
|
||||||
EOF
|
EOF
|
||||||
|
- uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: 3.8
|
||||||
- run: |
|
- run: |
|
||||||
twine upload --repository pypi dist/*.tar.gz
|
python -m pip install twine
|
||||||
|
- run: |
|
||||||
|
twine upload --repository pypi dist/*
|
||||||
|
|
|
@ -22,25 +22,22 @@ jobs:
|
||||||
experimental: [false]
|
experimental: [false]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- uses: actions/cache@v1
|
|
||||||
with:
|
|
||||||
path: ~/.cache/pip
|
|
||||||
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: coqui-ai/setup-python@pip-cache-key-py-ver
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
architecture: x64
|
architecture: x64
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: 'requirements*'
|
||||||
- name: check OS
|
- name: check OS
|
||||||
run: cat /etc/os-release
|
run: cat /etc/os-release
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt-get update
|
||||||
sudo apt install -y git make
|
sudo apt-get install -y git make gcc
|
||||||
sudo apt install -y python3-wheel gcc
|
|
||||||
make system-deps
|
make system-deps
|
||||||
- name: Upgrade pip
|
- name: Install/upgrade Python setup deps
|
||||||
run: python3 -m pip install --upgrade pip
|
run: python3 -m pip install --upgrade pip setuptools wheel
|
||||||
- name: Install TTS
|
- name: Install TTS
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install .[all]
|
python3 -m pip install .[all]
|
||||||
|
|
|
@ -22,25 +22,22 @@ jobs:
|
||||||
experimental: [false]
|
experimental: [false]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- uses: actions/cache@v1
|
|
||||||
with:
|
|
||||||
path: ~/.cache/pip
|
|
||||||
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: coqui-ai/setup-python@pip-cache-key-py-ver
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
architecture: x64
|
architecture: x64
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: 'requirements*'
|
||||||
- name: check OS
|
- name: check OS
|
||||||
run: cat /etc/os-release
|
run: cat /etc/os-release
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt-get update
|
||||||
sudo apt install -y git make
|
sudo apt-get install -y --no-install-recommends git make gcc
|
||||||
sudo apt install -y python3-wheel gcc
|
|
||||||
make system-deps
|
make system-deps
|
||||||
- name: Upgrade pip
|
- name: Install/upgrade Python setup deps
|
||||||
run: python3 -m pip install --upgrade pip
|
run: python3 -m pip install --upgrade pip setuptools wheel
|
||||||
- name: Install TTS
|
- name: Install TTS
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install .[all]
|
python3 -m pip install .[all]
|
||||||
|
|
|
@ -22,25 +22,22 @@ jobs:
|
||||||
experimental: [false]
|
experimental: [false]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- uses: actions/cache@v1
|
|
||||||
with:
|
|
||||||
path: ~/.cache/pip
|
|
||||||
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: coqui-ai/setup-python@pip-cache-key-py-ver
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
architecture: x64
|
architecture: x64
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: 'requirements*'
|
||||||
- name: check OS
|
- name: check OS
|
||||||
run: cat /etc/os-release
|
run: cat /etc/os-release
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt-get update
|
||||||
sudo apt install -y git make
|
sudo apt-get install -y git make gcc
|
||||||
sudo apt install -y python3-wheel gcc
|
|
||||||
make system-deps
|
make system-deps
|
||||||
- name: Upgrade pip
|
- name: Install/upgrade Python setup deps
|
||||||
run: python3 -m pip install --upgrade pip
|
run: python3 -m pip install --upgrade pip setuptools wheel
|
||||||
- name: Install TTS
|
- name: Install TTS
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install .[all]
|
python3 -m pip install .[all]
|
||||||
|
|
|
@ -22,25 +22,22 @@ jobs:
|
||||||
experimental: [false]
|
experimental: [false]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- uses: actions/cache@v1
|
|
||||||
with:
|
|
||||||
path: ~/.cache/pip
|
|
||||||
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: coqui-ai/setup-python@pip-cache-key-py-ver
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
architecture: x64
|
architecture: x64
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: 'requirements*'
|
||||||
- name: check OS
|
- name: check OS
|
||||||
run: cat /etc/os-release
|
run: cat /etc/os-release
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt-get update
|
||||||
sudo apt install -y git make
|
sudo apt-get install -y git make gcc
|
||||||
sudo apt install -y python3-wheel gcc
|
|
||||||
make system-deps
|
make system-deps
|
||||||
- name: Upgrade pip
|
- name: Install/upgrade Python setup deps
|
||||||
run: python3 -m pip install --upgrade pip
|
run: python3 -m pip install --upgrade pip setuptools wheel
|
||||||
- name: Install TTS
|
- name: Install TTS
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install .[all]
|
python3 -m pip install .[all]
|
||||||
|
|
|
@ -128,6 +128,8 @@ core
|
||||||
recipes/WIP/*
|
recipes/WIP/*
|
||||||
recipes/ljspeech/LJSpeech-1.1/*
|
recipes/ljspeech/LJSpeech-1.1/*
|
||||||
recipes/vctk/VCTK/*
|
recipes/vctk/VCTK/*
|
||||||
|
recipes/**/*.npy
|
||||||
|
recipes/**/*.json
|
||||||
VCTK-Corpus-removed-silence/*
|
VCTK-Corpus-removed-silence/*
|
||||||
|
|
||||||
# ignore training logs
|
# ignore training logs
|
||||||
|
@ -161,4 +163,5 @@ speakers.json
|
||||||
internal/*
|
internal/*
|
||||||
*_pitch.npy
|
*_pitch.npy
|
||||||
*_phoneme.npy
|
*_phoneme.npy
|
||||||
wandb
|
wandb
|
||||||
|
depot/*
|
README.md
|
@ -1,7 +1,7 @@
|
||||||
# <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>
|
# <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>
|
||||||
|
|
||||||
🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research and designed to achieve the best trade-off among ease of training, speed, and quality.
|
🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research and designed to achieve the best trade-off among ease of training, speed, and quality.
|
||||||
🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models) and tools for measuring dataset quality, and is already used in **20+ languages** for products and research projects.
|
🐸TTS comes with pretrained models and tools for measuring dataset quality, and is already used in **20+ languages** for products and research projects.
|
||||||
|
|
||||||
[](https://github.com/coqui-ai/TTS/actions)
|
[](https://github.com/coqui-ai/TTS/actions)
|
||||||
[](https://badge.fury.io/py/TTS)
|
[](https://badge.fury.io/py/TTS)
|
||||||
|
@ -135,6 +135,66 @@ $ make install
|
||||||
|
|
||||||
If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](https://stackoverflow.com/questions/66726331/how-can-i-run-mozilla-tts-coqui-tts-training-with-cuda-on-a-windows-system).
|
If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](https://stackoverflow.com/questions/66726331/how-can-i-run-mozilla-tts-coqui-tts-training-with-cuda-on-a-windows-system).
|
||||||
|
|
||||||
|
## Use TTS
|
||||||
|
|
||||||
|
### Single Speaker Models
|
||||||
|
|
||||||
|
- List provided models:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ tts --list_models
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run TTS with default models:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ tts --text "Text for TTS"
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run a TTS model with its default vocoder model:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run with specific TTS and vocoder models from the list:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run your own TTS model (Using Griffin-Lim Vocoder):
|
||||||
|
|
||||||
|
```
|
||||||
|
$ tts --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run your own TTS and Vocoder models:
|
||||||
|
```
|
||||||
|
$ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
|
||||||
|
--vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-speaker Models
|
||||||
|
|
||||||
|
- List the available speakers and choose as <speaker_id> among them:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run the multi-speaker TTS model with the target speaker ID:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run your own multi-speaker TTS model:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth.tar --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||||
|
```
|
||||||
|
|
||||||
## Directory Structure
|
## Directory Structure
|
||||||
```
|
```
|
||||||
|- notebooks/ (Jupyter Notebooks for model evaluation, parameter selection and data analysis.)
|
|- notebooks/ (Jupyter Notebooks for model evaluation, parameter selection and data analysis.)
|
||||||
|
|
|
@ -1,5 +1,17 @@
|
||||||
{
|
{
|
||||||
"tts_models": {
|
"tts_models": {
|
||||||
|
"multilingual":{
|
||||||
|
"multi-dataset":{
|
||||||
|
"your_tts":{
|
||||||
|
"description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
|
||||||
|
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
|
||||||
|
"default_vocoder": null,
|
||||||
|
"commit": "e9a1953e",
|
||||||
|
"license": "CC BY-NC-ND 4.0",
|
||||||
|
"contact": "egolge@coqui.ai"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"en": {
|
"en": {
|
||||||
"ek1": {
|
"ek1": {
|
||||||
"tacotron2": {
|
"tacotron2": {
|
||||||
|
@ -149,7 +161,7 @@
|
||||||
"commit": "bdab788d",
|
"commit": "bdab788d",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"contact": "",
|
"contact": "",
|
||||||
"default_vocoder": null
|
"default_vocoder": "vocoder_models/uk/mai/multiband-melgan"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -301,6 +313,17 @@
|
||||||
"commit": "3900448"
|
"commit": "3900448"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"uk": {
|
||||||
|
"mai": {
|
||||||
|
"multiband-melgan": {
|
||||||
|
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/vocoder_models--uk--mai--multiband-melgan.zip",
|
||||||
|
"author":"@robinhad",
|
||||||
|
"commit": "bdab788d",
|
||||||
|
"license": "MIT",
|
||||||
|
"contact": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
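The nested layout of `.models.json` above is exactly what the `tts` CLI flattens into model names such as `tts_models/multilingual/multi-dataset/your_tts`. Below is a minimal stdlib sketch of that mapping, using a trimmed copy of the entry shown above rather than the library's own ModelManager.

```python
# Minimal sketch (not the library's ModelManager): flatten the nesting of
# .models.json into the "<type>/<language>/<dataset>/<model>" names the CLI uses.
import json

MODELS_JSON = """
{
  "tts_models": {
    "multilingual": {
      "multi-dataset": {
        "your_tts": {
          "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
          "default_vocoder": null,
          "license": "CC BY-NC-ND 4.0"
        }
      }
    }
  }
}
"""

def list_model_names(catalog: dict):
    """Yield '<type>/<language>/<dataset>/<model>' for every leaf entry."""
    for model_type, languages in catalog.items():
        for language, datasets in languages.items():
            for dataset, models in datasets.items():
                for model_name in models:
                    yield f"{model_type}/{language}/{dataset}/{model_name}"

for name in list_model_names(json.loads(MODELS_JSON)):
    print(name)  # tts_models/multilingual/multi-dataset/your_tts
```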
@ -1 +1 @@
|
||||||
0.4.2
|
0.5.0
|
|
@ -12,7 +12,7 @@ from tqdm import tqdm
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.tts.utils.speakers import get_speaker_manager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import count_parameters
|
from TTS.utils.generic_utils import count_parameters
|
||||||
|
|
||||||
|
@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False):
|
||||||
enable_eos_bos=c.enable_eos_bos_chars,
|
enable_eos_bos=c.enable_eos_bos_chars,
|
||||||
use_noise_augment=False,
|
use_noise_augment=False,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
speaker_id_mapping=speaker_manager.speaker_ids,
|
speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
|
||||||
d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
|
d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if c.use_phonemes and c.compute_input_seq_cache:
|
if c.use_phonemes and c.compute_input_seq_cache:
|
||||||
|
@ -234,8 +234,13 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
# use eval and training partitions
|
# use eval and training partitions
|
||||||
meta_data = meta_data_train + meta_data_eval
|
meta_data = meta_data_train + meta_data_eval
|
||||||
|
|
||||||
# parse speakers
|
# init speaker manager
|
||||||
speaker_manager = get_speaker_manager(c, args, meta_data_train)
|
if c.use_speaker_embedding:
|
||||||
|
speaker_manager = SpeakerManager(data_items=meta_data)
|
||||||
|
elif c.use_d_vector_file:
|
||||||
|
speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
|
||||||
|
else:
|
||||||
|
speaker_manager = None
|
||||||
|
|
||||||
# setup model
|
# setup model
|
||||||
model = setup_model(c)
|
model = setup_model(c)
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""Find all the unique characters in a dataset"""
|
||||||
|
import argparse
|
||||||
|
import multiprocessing
|
||||||
|
from argparse import RawTextHelpFormatter
|
||||||
|
|
||||||
|
from tqdm.contrib.concurrent import process_map
|
||||||
|
|
||||||
|
from TTS.config import load_config
|
||||||
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
from TTS.tts.utils.text import text2phone
|
||||||
|
|
||||||
|
|
||||||
|
def compute_phonemes(item):
|
||||||
|
try:
|
||||||
|
text = item[0]
|
||||||
|
language = item[-1]
|
||||||
|
ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
|
||||||
|
except:
|
||||||
|
return []
|
||||||
|
return list(set(ph))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# pylint: disable=W0601
|
||||||
|
global c
|
||||||
|
# pylint: disable=bad-option-value
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
|
||||||
|
"""
|
||||||
|
Example runs:
|
||||||
|
|
||||||
|
python TTS/bin/find_unique_phonemes.py --config_path config.json
|
||||||
|
""",
|
||||||
|
formatter_class=RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
c = load_config(args.config_path)
|
||||||
|
|
||||||
|
# load all datasets
|
||||||
|
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
|
||||||
|
items = train_items + eval_items
|
||||||
|
print("Num items:", len(items))
|
||||||
|
|
||||||
|
phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
|
||||||
|
phones = []
|
||||||
|
for ph in phonemes:
|
||||||
|
phones.extend(ph)
|
||||||
|
phones = set(phones)
|
||||||
|
lower_phones = filter(lambda c: c.islower(), phones)
|
||||||
|
phones_force_lower = [c.lower() for c in phones]
|
||||||
|
phones_force_lower = set(phones_force_lower)
|
||||||
|
|
||||||
|
print(f" > Number of unique phonemes: {len(phones)}")
|
||||||
|
print(f" > Unique phonemes: {''.join(sorted(phones))}")
|
||||||
|
print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
|
||||||
|
print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
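The script above fans `compute_phonemes` out over all CPU cores with `tqdm.contrib.concurrent.process_map` and then merges the per-item results into one set. Below is a self-contained sketch of the same map-then-deduplicate pattern on toy strings, with no TTS dataset loading involved.

```python
# Sketch of the parallel map-then-deduplicate pattern used above, applied to
# plain characters instead of phonemes (toy data, no TTS dependencies).
import multiprocessing

from tqdm.contrib.concurrent import process_map

def unique_chars(text: str) -> list:
    # One worker call per item, mirroring compute_phonemes() above.
    return list(set(text))

if __name__ == "__main__":
    items = ["hello world", "text to speech", "coqui tts"]
    per_item = process_map(unique_chars, items,
                           max_workers=multiprocessing.cpu_count(), chunksize=2)
    chars = set()
    for found in per_item:
        chars.update(found)
    print(f" > Number of unique characters: {len(chars)}")
    print(f" > Unique characters: {''.join(sorted(chars))}")
```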
@ -0,0 +1,89 @@
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from tqdm.contrib.concurrent import process_map
|
||||||
|
|
||||||
|
from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave
|
||||||
|
|
||||||
|
|
||||||
|
def remove_silence(filepath):
|
||||||
|
output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
|
||||||
|
# ignore if the file exists
|
||||||
|
if os.path.exists(output_path) and not args.force:
|
||||||
|
return
|
||||||
|
|
||||||
|
# create all directory structure
|
||||||
|
pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# load wave
|
||||||
|
audio, sample_rate = read_wave(filepath)
|
||||||
|
|
||||||
|
# get speech segments
|
||||||
|
segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=args.aggressiveness)
|
||||||
|
|
||||||
|
segments = list(segments)
|
||||||
|
num_segments = len(segments)
|
||||||
|
flag = False
|
||||||
|
# create the output wave
|
||||||
|
if num_segments != 0:
|
||||||
|
for i, segment in reversed(list(enumerate(segments))):
|
||||||
|
if i >= 1:
|
||||||
|
if not flag:
|
||||||
|
concat_segment = segment
|
||||||
|
flag = True
|
||||||
|
else:
|
||||||
|
concat_segment = segment + concat_segment
|
||||||
|
else:
|
||||||
|
if flag:
|
||||||
|
segment = segment + concat_segment
|
||||||
|
# print("Saving: ", output_path)
|
||||||
|
write_wave(output_path, segment, sample_rate)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print("> Just Copying the file to:", output_path)
|
||||||
|
# if fail to remove silence just write the file
|
||||||
|
write_wave(output_path, audio, sample_rate)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_audios():
|
||||||
|
files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
|
||||||
|
print("> Number of files: ", len(files))
|
||||||
|
if not args.force:
|
||||||
|
print("> Ignoring files that already exist in the output directory.")
|
||||||
|
|
||||||
|
if files:
|
||||||
|
# create threads
|
||||||
|
num_threads = multiprocessing.cpu_count()
|
||||||
|
process_map(remove_silence, files, max_workers=num_threads, chunksize=15)
|
||||||
|
else:
|
||||||
|
print("> No files Found !")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2"
|
||||||
|
)
|
||||||
|
parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
|
||||||
|
parser.add_argument(
|
||||||
|
"-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
|
||||||
|
)
|
||||||
|
parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
|
||||||
|
parser.add_argument(
|
||||||
|
"-g",
|
||||||
|
"--glob",
|
||||||
|
type=str,
|
||||||
|
default="**/*.wav",
|
||||||
|
help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-a",
|
||||||
|
"--aggressiveness",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
preprocess_audios()
|
|
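The reversed loop in `remove_silence` above is dense; its net effect is simply to concatenate every detected speech segment in original order, or to copy the file unchanged when the VAD finds no speech. A sketch of that intent, assuming, as the script does, that segments support `+` concatenation:

```python
# Equivalent intent of the reversed-loop logic above, assuming VAD segments
# support "+" concatenation (the same assumption the script relies on).
def join_speech_segments(segments, audio):
    segments = list(segments)
    if not segments:
        return audio            # no speech detected: keep the original audio
    joined = segments[0]
    for segment in segments[1:]:
        joined = joined + segment
    return joined               # then: write_wave(output_path, joined, sample_rate)
```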
@ -23,72 +23,76 @@ def str2bool(v):
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# pylint: disable=bad-option-value
|
description = """Synthesize speech on command line.
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="""Synthesize speech on command line.\n\n"""
|
|
||||||
"""You can either use your trained model or choose a model from the provided list.\n\n"""
|
|
||||||
"""If you don't specify any models, then it uses LJSpeech based English model.\n\n"""
|
|
||||||
"""
|
|
||||||
# Example Runs:
|
|
||||||
|
|
||||||
## Single Speaker Models
|
You can either use your trained model or choose a model from the provided list.
|
||||||
|
|
||||||
- list provided models
|
If you don't specify any models, then it uses LJSpeech based English model.
|
||||||
|
|
||||||
|
## Example Runs
|
||||||
|
|
||||||
|
### Single Speaker Models
|
||||||
|
|
||||||
|
- List provided models:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ./TTS/bin/synthesize.py --list_models
|
$ tts --list_models
|
||||||
```
|
```
|
||||||
|
|
||||||
- run tts with default models.
|
- Run TTS with default models:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ./TTS/bin synthesize.py --text "Text for TTS"
|
$ tts --text "Text for TTS"
|
||||||
```
|
```
|
||||||
|
|
||||||
- run a tts model with its default vocoder model.
|
- Run a TTS model with its default vocoder model:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ./TTS/bin synthesize.py --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>"
|
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>"
|
||||||
```
|
```
|
||||||
|
|
||||||
- run with specific tts and vocoder models from the list
|
- Run with specific TTS and vocoder models from the list:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ./TTS/bin/synthesize.py --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
|
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
|
||||||
```
|
```
|
||||||
|
|
||||||
- run your own TTS model (Using Griffin-Lim Vocoder)
|
- Run your own TTS model (Using Griffin-Lim Vocoder):
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
|
$ tts --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
|
||||||
```
|
```
|
||||||
|
|
||||||
- run your own TTS and Vocoder models
|
- Run your own TTS and Vocoder models:
|
||||||
```
|
```
|
||||||
$ ./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
|
$ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
|
||||||
--vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json
|
--vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json
|
||||||
```
|
```
|
||||||
|
|
||||||
## MULTI-SPEAKER MODELS
|
### Multi-speaker Models
|
||||||
|
|
||||||
- list the available speakers and choose as <speaker_id> among them.
|
- List the available speakers and choose as <speaker_id> among them:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ./TTS/bin/synthesize.py --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||||
```
|
```
|
||||||
|
|
||||||
- run the multi-speaker TTS model with the target speaker ID.
|
- Run the multi-speaker TTS model with the target speaker ID:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ./TTS/bin/synthesize.py --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||||
```
|
```
|
||||||
|
|
||||||
- run your own multi-speaker TTS model.
|
- Run your own multi-speaker TTS model:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ./TTS/bin/synthesize.py --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth.tar --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth.tar --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||||
```
|
```
|
||||||
""",
|
"""
|
||||||
|
# We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
|
||||||
|
# documentation in sync more easily.
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=description.replace(" ```\n", ""),
|
||||||
formatter_class=RawTextHelpFormatter,
|
formatter_class=RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -98,7 +102,7 @@ def main():
|
||||||
nargs="?",
|
nargs="?",
|
||||||
const=True,
|
const=True,
|
||||||
default=False,
|
default=False,
|
||||||
help="list available pre-trained tts and vocoder models.",
|
help="list available pre-trained TTS and vocoder models.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")
|
parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")
|
||||||
|
|
||||||
|
@ -107,7 +111,7 @@ def main():
|
||||||
"--model_name",
|
"--model_name",
|
||||||
type=str,
|
type=str,
|
||||||
default="tts_models/en/ljspeech/tacotron2-DDC",
|
default="tts_models/en/ljspeech/tacotron2-DDC",
|
||||||
help="Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>",
|
help="Name of one of the pre-trained TTS models in format <language>/<dataset>/<model_name>",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocoder_name",
|
"--vocoder_name",
|
||||||
|
@ -148,12 +152,19 @@ def main():
|
||||||
|
|
||||||
# args for multi-speaker synthesis
|
# args for multi-speaker synthesis
|
||||||
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
|
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
|
||||||
|
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speaker_idx",
|
"--speaker_idx",
|
||||||
type=str,
|
type=str,
|
||||||
help="Target speaker ID for a multi-speaker TTS model.",
|
help="Target speaker ID for a multi-speaker TTS model.",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--language_idx",
|
||||||
|
type=str,
|
||||||
|
help="Target language ID for a multi-lingual TTS model.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speaker_wav",
|
"--speaker_wav",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
|
@ -169,6 +180,14 @@ def main():
|
||||||
const=True,
|
const=True,
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--list_language_idxs",
|
||||||
|
help="List available language ids for the defined multi-lingual model.",
|
||||||
|
type=str2bool,
|
||||||
|
nargs="?",
|
||||||
|
const=True,
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
# aux args
|
# aux args
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_spectogram",
|
"--save_spectogram",
|
||||||
|
@ -180,7 +199,7 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# print the description if either text or list_models is not set
|
# print the description if either text or list_models is not set
|
||||||
if args.text is None and not args.list_models and not args.list_speaker_idxs:
|
if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs:
|
||||||
parser.parse_args(["-h"])
|
parser.parse_args(["-h"])
|
||||||
|
|
||||||
# load model manager
|
# load model manager
|
||||||
|
@ -190,6 +209,7 @@ def main():
|
||||||
model_path = None
|
model_path = None
|
||||||
config_path = None
|
config_path = None
|
||||||
speakers_file_path = None
|
speakers_file_path = None
|
||||||
|
language_ids_file_path = None
|
||||||
vocoder_path = None
|
vocoder_path = None
|
||||||
vocoder_config_path = None
|
vocoder_config_path = None
|
||||||
encoder_path = None
|
encoder_path = None
|
||||||
|
@ -213,6 +233,7 @@ def main():
|
||||||
model_path = args.model_path
|
model_path = args.model_path
|
||||||
config_path = args.config_path
|
config_path = args.config_path
|
||||||
speakers_file_path = args.speakers_file_path
|
speakers_file_path = args.speakers_file_path
|
||||||
|
language_ids_file_path = args.language_ids_file_path
|
||||||
|
|
||||||
if args.vocoder_path is not None:
|
if args.vocoder_path is not None:
|
||||||
vocoder_path = args.vocoder_path
|
vocoder_path = args.vocoder_path
|
||||||
|
@ -227,6 +248,7 @@ def main():
|
||||||
model_path,
|
model_path,
|
||||||
config_path,
|
config_path,
|
||||||
speakers_file_path,
|
speakers_file_path,
|
||||||
|
language_ids_file_path,
|
||||||
vocoder_path,
|
vocoder_path,
|
||||||
vocoder_config_path,
|
vocoder_config_path,
|
||||||
encoder_path,
|
encoder_path,
|
||||||
|
@ -242,6 +264,14 @@ def main():
|
||||||
print(synthesizer.tts_model.speaker_manager.speaker_ids)
|
print(synthesizer.tts_model.speaker_manager.speaker_ids)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# query language ids of a multi-lingual model.
|
||||||
|
if args.list_language_idxs:
|
||||||
|
print(
|
||||||
|
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||||
|
)
|
||||||
|
print(synthesizer.tts_model.language_manager.language_id_mapping)
|
||||||
|
return
|
||||||
|
|
||||||
# check the arguments against a multi-speaker model.
|
# check the arguments against a multi-speaker model.
|
||||||
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
|
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
|
||||||
print(
|
print(
|
||||||
|
@ -254,7 +284,7 @@ def main():
|
||||||
print(" > Text: {}".format(args.text))
|
print(" > Text: {}".format(args.text))
|
||||||
|
|
||||||
# kick it
|
# kick it
|
||||||
wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style)
|
wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav)
|
||||||
|
|
||||||
# save the results
|
# save the results
|
||||||
print(" > Saving output to {}".format(args.out_path))
|
print(" > Saving output to {}".format(args.out_path))
|
||||||
|
|
|
@ -11,7 +11,7 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
|
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
||||||
from TTS.speaker_encoder.utils.training import init_training
|
from TTS.speaker_encoder.utils.training import init_training
|
||||||
from TTS.speaker_encoder.utils.visual import plot_embeddings
|
from TTS.speaker_encoder.utils.visual import plot_embeddings
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
@ -151,7 +151,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
global meta_data_eval
|
global meta_data_eval
|
||||||
|
|
||||||
ap = AudioProcessor(**c.audio)
|
ap = AudioProcessor(**c.audio)
|
||||||
model = setup_model(c)
|
model = setup_speaker_encoder_model(c)
|
||||||
|
|
||||||
optimizer = RAdam(model.parameters(), lr=c.lr)
|
optimizer = RAdam(model.parameters(), lr=c.lr)
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
import os
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
from TTS.config import load_config, register_config
|
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
|
||||||
from TTS.trainer import Trainer, TrainingArgs
|
from TTS.trainer import Trainer, TrainingArgs
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
@ -45,15 +47,39 @@ def main():
|
||||||
ap = AudioProcessor(**config.audio)
|
ap = AudioProcessor(**config.audio)
|
||||||
|
|
||||||
# init speaker manager
|
# init speaker manager
|
||||||
if config.use_speaker_embedding:
|
if check_config_and_model_args(config, "use_speaker_embedding", True):
|
||||||
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
|
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
|
||||||
elif config.use_d_vector_file:
|
if hasattr(config, "model_args"):
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
|
else:
|
||||||
|
config.num_speakers = speaker_manager.num_speakers
|
||||||
|
elif check_config_and_model_args(config, "use_d_vector_file", True):
|
||||||
|
if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
|
||||||
|
speaker_manager = SpeakerManager(
|
||||||
|
d_vectors_file_path=config.model_args.d_vector_file,
|
||||||
|
encoder_model_path=config.model_args.speaker_encoder_model_path,
|
||||||
|
encoder_config_path=config.model_args.speaker_encoder_config_path,
|
||||||
|
use_cuda=torch.cuda.is_available(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
|
||||||
|
config.num_speakers = speaker_manager.num_speakers
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
else:
|
else:
|
||||||
speaker_manager = None
|
speaker_manager = None
|
||||||
|
|
||||||
|
if check_config_and_model_args(config, "use_language_embedding", True):
|
||||||
|
language_manager = LanguageManager(config=config)
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
config.model_args.num_languages = language_manager.num_languages
|
||||||
|
else:
|
||||||
|
config.num_languages = language_manager.num_languages
|
||||||
|
else:
|
||||||
|
language_manager = None
|
||||||
|
|
||||||
# init the model from config
|
# init the model from config
|
||||||
model = setup_model(config, speaker_manager)
|
model = setup_model(config, speaker_manager, language_manager)
|
||||||
|
|
||||||
# init the trainer and 🚀
|
# init the trainer and 🚀
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|
|
@ -95,3 +95,38 @@ def load_config(config_path: str) -> None:
|
||||||
config = config_class()
|
config = config_class()
|
||||||
config.from_dict(config_dict)
|
config.from_dict(config_dict)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def check_config_and_model_args(config, arg_name, value):
|
||||||
|
"""Check the give argument in `config.model_args` if exist or in `config` for
|
||||||
|
the given value.
|
||||||
|
|
||||||
|
Return False if the argument does not exist in `config.model_args` or `config`.
|
||||||
|
This is to patch up the compatibility between models with and without `model_args`.
|
||||||
|
|
||||||
|
TODO: Remove this in the future with a unified approach.
|
||||||
|
"""
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
if arg_name in config.model_args:
|
||||||
|
return config.model_args[arg_name] == value
|
||||||
|
if hasattr(config, arg_name):
|
||||||
|
return config[arg_name] == value
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_from_config_or_model_args(config, arg_name):
|
||||||
|
"""Get the given argument from `config.model_args` if exist or in `config`."""
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
if arg_name in config.model_args:
|
||||||
|
return config.model_args[arg_name]
|
||||||
|
return config[arg_name]
|
||||||
|
|
||||||
|
|
||||||
|
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
|
||||||
|
"""Get the given argument from `config.model_args` if exist or in `config`."""
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
if arg_name in config.model_args:
|
||||||
|
return config.model_args[arg_name]
|
||||||
|
if hasattr(config, arg_name):
|
||||||
|
return config[arg_name]
|
||||||
|
return def_val
|
||||||
|
|
|
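These helpers exist because newer models keep their options under `config.model_args` while older ones keep them at the top level of `config`. Below is a small sketch of the fallback order using a hypothetical dict-backed stand-in (the real configs are Coqpit dataclasses); the import path follows the `from TTS.config import check_config_and_model_args, get_from_config_or_model_args, ...` line used by `train_tts.py` above.

```python
# Sketch of the fallback behaviour with a stand-in config object
# (hypothetical DummyConfig; the real configs are Coqpit dataclasses).
from TTS.config import check_config_and_model_args, get_from_config_or_model_args

class DummyConfig(dict):
    """Dict that also exposes its keys as attributes, like a Coqpit config."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

nested = DummyConfig(model_args=DummyConfig(use_speaker_embedding=True, d_vector_file="d_vectors.json"))
flat = DummyConfig(use_speaker_embedding=True, d_vector_file="d_vectors.json")

# Both layouts answer the same question, regardless of where the flag lives.
assert check_config_and_model_args(nested, "use_speaker_embedding", True)
assert check_config_and_model_args(flat, "use_speaker_embedding", True)

# Values are read from model_args first, falling back to the top level.
assert get_from_config_or_model_args(nested, "d_vector_file") == "d_vectors.json"
assert get_from_config_or_model_args(flat, "d_vector_file") == "d_vectors.json"
```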
@ -60,6 +60,12 @@ class BaseAudioConfig(Coqpit):
|
||||||
trim_db (int):
|
trim_db (int):
|
||||||
Silence threshold used for silence trimming. Defaults to 45.
|
Silence threshold used for silence trimming. Defaults to 45.
|
||||||
|
|
||||||
|
do_rms_norm (bool, optional):
|
||||||
|
Enable/disable RMS volume normalization when loading an audio file. Defaults to False.
|
||||||
|
|
||||||
|
db_level (float, optional):
|
||||||
|
dB level used for RMS normalization. The range is -99 to 0. Defaults to None.
|
||||||
|
|
||||||
power (float):
|
power (float):
|
||||||
Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
|
Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
|
||||||
artifacts in the synthesized voice. Defaults to 1.5.
|
artifacts in the synthesized voice. Defaults to 1.5.
|
||||||
|
@ -116,6 +122,9 @@ class BaseAudioConfig(Coqpit):
|
||||||
# silence trimming
|
# silence trimming
|
||||||
do_trim_silence: bool = True
|
do_trim_silence: bool = True
|
||||||
trim_db: int = 45
|
trim_db: int = 45
|
||||||
|
# rms volume normalization
|
||||||
|
do_rms_norm: bool = False
|
||||||
|
db_level: float = None
|
||||||
# griffin-lim params
|
# griffin-lim params
|
||||||
power: float = 1.5
|
power: float = 1.5
|
||||||
griffin_lim_iters: int = 60
|
griffin_lim_iters: int = 60
|
||||||
|
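The new `do_rms_norm`/`db_level` fields describe RMS volume normalization applied when audio is loaded. A back-of-the-envelope numpy sketch of the underlying idea, scaling a signal so its RMS energy sits at a target dB level; this illustrates the concept only and is not the library's loader code:

```python
# Illustration of RMS volume normalization to a target dB level (the idea
# behind do_rms_norm / db_level); not the library's own loader code.
import numpy as np

def rms_norm(wav: np.ndarray, db_level: float = -27.0) -> np.ndarray:
    """Scale `wav` so that its RMS energy matches `db_level` dBFS."""
    target_rms = 10 ** (db_level / 20)          # dB -> linear amplitude
    current_rms = np.sqrt(np.mean(wav ** 2))    # current RMS of the signal
    if current_rms == 0:
        return wav                              # silence: nothing to scale
    return wav * (target_rms / current_rms)

wav = np.random.uniform(-0.1, 0.1, size=22050).astype(np.float32)
normed = rms_norm(wav, db_level=-27.0)
print(20 * np.log10(np.sqrt(np.mean(normed ** 2))))  # ~ -27.0
```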
@ -184,9 +193,12 @@ class BaseDatasetConfig(Coqpit):
|
||||||
Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
|
Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
|
||||||
unused_speakers (List):
|
ignored_speakers (List):
|
||||||
List of speakers IDs that are not used at the training. Default None.
|
List of speakers IDs that are not used at the training. Default None.
|
||||||
|
|
||||||
|
language (str):
|
||||||
|
Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to None.
|
||||||
|
|
||||||
meta_file_val (str):
|
meta_file_val (str):
|
||||||
Name of the dataset meta file that defines the instances used at validation.
|
Name of the dataset meta file that defines the instances used at validation.
|
||||||
|
|
||||||
|
@ -198,7 +210,8 @@ class BaseDatasetConfig(Coqpit):
|
||||||
name: str = ""
|
name: str = ""
|
||||||
path: str = ""
|
path: str = ""
|
||||||
meta_file_train: str = ""
|
meta_file_train: str = ""
|
||||||
ununsed_speakers: List[str] = None
|
ignored_speakers: List[str] = None
|
||||||
|
language: str = ""
|
||||||
meta_file_val: str = ""
|
meta_file_val: str = ""
|
||||||
meta_file_attn_mask: str = ""
|
meta_file_attn_mask: str = ""
|
||||||
|
|
||||||
|
@ -335,6 +348,8 @@ class BaseTrainingConfig(Coqpit):
|
||||||
num_loader_workers: int = 0
|
num_loader_workers: int = 0
|
||||||
num_eval_loader_workers: int = 0
|
num_eval_loader_workers: int = 0
|
||||||
use_noise_augment: bool = False
|
use_noise_augment: bool = False
|
||||||
|
use_language_weighted_sampler: bool = False
|
||||||
|
|
||||||
# paths
|
# paths
|
||||||
output_path: str = None
|
output_path: str = None
|
||||||
# distributed
|
# distributed
|
||||||
|
|
|
@ -100,7 +100,15 @@ if args.vocoder_path is not None:
|
||||||
|
|
||||||
# load models
|
# load models
|
||||||
synthesizer = Synthesizer(
|
synthesizer = Synthesizer(
|
||||||
model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
|
tts_checkpoint=model_path,
|
||||||
|
tts_config_path=config_path,
|
||||||
|
tts_speakers_file=speakers_file_path,
|
||||||
|
tts_languages_file=None,
|
||||||
|
vocoder_checkpoint=vocoder_path,
|
||||||
|
vocoder_config=vocoder_config_path,
|
||||||
|
encoder_checkpoint="",
|
||||||
|
encoder_config="",
|
||||||
|
use_cuda=args.use_cuda,
|
||||||
)
|
)
|
||||||
|
|
||||||
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
|
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
|
||||||
|
@ -165,7 +173,7 @@ def tts():
|
||||||
|
|
||||||
style_wav = style_wav_uri_to_dict(style_wav)
|
style_wav = style_wav_uri_to_dict(style_wav)
|
||||||
print(" > Model input: {}".format(text))
|
print(" > Model input: {}".format(text))
|
||||||
wavs = synthesizer.tts(text, speaker_idx=speaker_idx, style_wav=style_wav)
|
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
|
||||||
out = io.BytesIO()
|
out = io.BytesIO()
|
||||||
synthesizer.save_wav(wavs, out)
|
synthesizer.save_wav(wavs, out)
|
||||||
return send_file(out, mimetype="audio/wav")
|
return send_file(out, mimetype="audio/wav")
|
||||||
|
|
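The keyword-only `Synthesizer(...)` call above makes the long positional signature explicit. Below is a sketch of the same calls used programmatically outside the server, assuming `Synthesizer` is importable from `TTS.utils.synthesizer`; all file paths are placeholders.

```python
# Sketch of programmatic use mirroring the server's calls above; the import
# location is assumed and every file path is a placeholder.
import io

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="path/to/model.pth.tar",
    tts_config_path="path/to/config.json",
    tts_speakers_file=None,
    tts_languages_file=None,
    vocoder_checkpoint="path/to/vocoder.pth.tar",
    vocoder_config="path/to/vocoder_config.json",
    encoder_checkpoint="",
    encoder_config="",
    use_cuda=False,
)

wavs = synthesizer.tts("Text for TTS", speaker_name=None, style_wav=None)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)   # same call the /api/tts endpoint uses
```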
|
@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset):
|
||||||
feats = torch.stack(feats)
|
feats = torch.stack(feats)
|
||||||
labels = torch.stack(labels)
|
labels = torch.stack(labels)
|
||||||
|
|
||||||
return feats.transpose(1, 2), labels
|
return feats, labels
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torchaudio
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from TTS.speaker_encoder.models.resnet import PreEmphasis
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,9 +35,22 @@ class LSTMWithoutProjection(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class LSTMSpeakerEncoder(nn.Module):
|
class LSTMSpeakerEncoder(nn.Module):
|
||||||
def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
proj_dim=256,
|
||||||
|
lstm_dim=768,
|
||||||
|
num_lstm_layers=3,
|
||||||
|
use_lstm_with_projection=True,
|
||||||
|
use_torch_spec=False,
|
||||||
|
audio_config=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_lstm_with_projection = use_lstm_with_projection
|
self.use_lstm_with_projection = use_lstm_with_projection
|
||||||
|
self.use_torch_spec = use_torch_spec
|
||||||
|
self.audio_config = audio_config
|
||||||
|
self.proj_dim = proj_dim
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
# choose LSTM layer
|
# choose LSTM layer
|
||||||
if use_lstm_with_projection:
|
if use_lstm_with_projection:
|
||||||
|
@ -46,6 +61,38 @@ class LSTMSpeakerEncoder(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
|
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
|
||||||
|
|
||||||
|
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||||
|
|
||||||
|
if self.use_torch_spec:
|
||||||
|
self.torch_spec = torch.nn.Sequential(
|
||||||
|
PreEmphasis(audio_config["preemphasis"]),
|
||||||
|
# TorchSTFT(
|
||||||
|
# n_fft=audio_config["fft_size"],
|
||||||
|
# hop_length=audio_config["hop_length"],
|
||||||
|
# win_length=audio_config["win_length"],
|
||||||
|
# sample_rate=audio_config["sample_rate"],
|
||||||
|
# window="hamming_window",
|
||||||
|
# mel_fmin=0.0,
|
||||||
|
# mel_fmax=None,
|
||||||
|
# use_htk=True,
|
||||||
|
# do_amp_to_db=False,
|
||||||
|
# n_mels=audio_config["num_mels"],
|
||||||
|
# power=2.0,
|
||||||
|
# use_mel=True,
|
||||||
|
# mel_norm=None,
|
||||||
|
# )
|
||||||
|
torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=audio_config["sample_rate"],
|
||||||
|
n_fft=audio_config["fft_size"],
|
||||||
|
win_length=audio_config["win_length"],
|
||||||
|
hop_length=audio_config["hop_length"],
|
||||||
|
window_fn=torch.hamming_window,
|
||||||
|
n_mels=audio_config["num_mels"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.torch_spec = None
|
||||||
|
|
||||||
self._init_layers()
|
self._init_layers()
|
||||||
|
|
||||||
def _init_layers(self):
|
def _init_layers(self):
|
||||||
|
@ -55,22 +102,33 @@ class LSTMSpeakerEncoder(nn.Module):
|
||||||
elif "weight" in name:
|
elif "weight" in name:
|
||||||
nn.init.xavier_normal_(param)
|
nn.init.xavier_normal_(param)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, l2_norm=True):
|
||||||
# TODO: implement state passing for lstms
|
"""Forward pass of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `use_torch_spec` must be `True`
|
||||||
|
to compute the spectrogram on-the-fly.
|
||||||
|
l2_norm (bool): Whether to L2-normalize the outputs.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
if self.use_torch_spec:
|
||||||
|
x.squeeze_(1)
|
||||||
|
x = self.torch_spec(x)
|
||||||
|
x = self.instancenorm(x).transpose(1, 2)
|
||||||
d = self.layers(x)
|
d = self.layers(x)
|
||||||
if self.use_lstm_with_projection:
|
if self.use_lstm_with_projection:
|
||||||
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
|
d = d[:, -1]
|
||||||
else:
|
if l2_norm:
|
||||||
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, x):
|
def inference(self, x, l2_norm=True):
|
||||||
d = self.layers.forward(x)
|
d = self.forward(x, l2_norm=l2_norm)
|
||||||
if self.use_lstm_with_projection:
|
|
||||||
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
|
|
||||||
else:
|
|
||||||
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
|
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
|
||||||
|
|
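The refactor above funnels `inference` through `forward` and makes L2 normalization an explicit `l2_norm` flag. The reason for normalizing at all: on unit-norm embeddings, the cosine similarity used to compare speakers reduces to a plain dot product. A tiny torch check of that equivalence, independent of the encoder itself:

```python
# Why the speaker encoders L2-normalize their outputs: on unit-norm embeddings
# the cosine similarity used for speaker comparison is just a dot product.
import torch
import torch.nn.functional as F

a = torch.randn(4, 256)   # a batch of raw speaker embeddings
b = torch.randn(4, 256)

a_n = F.normalize(a, p=2, dim=1)   # same normalize call the encoders use
b_n = F.normalize(b, p=2, dim=1)

dot = (a_n * b_n).sum(dim=1)
cos = F.cosine_similarity(a, b, dim=1)
print(torch.allclose(dot, cos, atol=1e-6))   # True
```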
|
@ -1,10 +1,25 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torchaudio
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
# from TTS.utils.audio import TorchSTFT
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
|
||||||
|
class PreEmphasis(nn.Module):
|
||||||
|
def __init__(self, coefficient=0.97):
|
||||||
|
super().__init__()
|
||||||
|
self.coefficient = coefficient
|
||||||
|
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert len(x.size()) == 2
|
||||||
|
|
||||||
|
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
|
||||||
|
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
|
||||||
|
|
||||||
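`PreEmphasis` realizes the classic high-pass pre-emphasis filter `y[t] = x[t] - coefficient * x[t-1]` as a fixed two-tap convolution. A quick torch check that the convolution form matches the direct difference equation, using only the operations shown above:

```python
# Quick check that the conv1d-based PreEmphasis above matches the textbook
# difference equation y[t] = x[t] - coefficient * x[t-1] (reflect-padded at t=0).
import torch

coefficient = 0.97
x = torch.randn(1, 16)                       # (batch, time), as the module expects

# Direct form, with the same reflect padding used by PreEmphasis.
padded = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect").squeeze(1)
direct = padded[:, 1:] - coefficient * padded[:, :-1]

# Conv form, identical to PreEmphasis.forward().
kernel = torch.tensor([-coefficient, 1.0]).view(1, 1, 2)
conv = torch.nn.functional.conv1d(padded.unsqueeze(1), kernel).squeeze(1)

print(torch.allclose(direct, conv))          # True
```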
|
|
||||||
class SELayer(nn.Module):
|
class SELayer(nn.Module):
|
||||||
def __init__(self, channel, reduction=8):
|
def __init__(self, channel, reduction=8):
|
||||||
super(SELayer, self).__init__()
|
super(SELayer, self).__init__()
|
||||||
|
@ -70,12 +85,18 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
num_filters=[32, 64, 128, 256],
|
num_filters=[32, 64, 128, 256],
|
||||||
encoder_type="ASP",
|
encoder_type="ASP",
|
||||||
log_input=False,
|
log_input=False,
|
||||||
|
use_torch_spec=False,
|
||||||
|
audio_config=None,
|
||||||
):
|
):
|
||||||
super(ResNetSpeakerEncoder, self).__init__()
|
super(ResNetSpeakerEncoder, self).__init__()
|
||||||
|
|
||||||
self.encoder_type = encoder_type
|
self.encoder_type = encoder_type
|
||||||
self.input_dim = input_dim
|
self.input_dim = input_dim
|
||||||
self.log_input = log_input
|
self.log_input = log_input
|
||||||
|
self.use_torch_spec = use_torch_spec
|
||||||
|
self.audio_config = audio_config
|
||||||
|
self.proj_dim = proj_dim
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
|
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
self.bn1 = nn.BatchNorm2d(num_filters[0])
|
self.bn1 = nn.BatchNorm2d(num_filters[0])
|
||||||
|
@ -88,6 +109,36 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
|
|
||||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||||
|
|
||||||
|
if self.use_torch_spec:
|
||||||
|
self.torch_spec = torch.nn.Sequential(
|
||||||
|
PreEmphasis(audio_config["preemphasis"]),
|
||||||
|
# TorchSTFT(
|
||||||
|
# n_fft=audio_config["fft_size"],
|
||||||
|
# hop_length=audio_config["hop_length"],
|
||||||
|
# win_length=audio_config["win_length"],
|
||||||
|
# sample_rate=audio_config["sample_rate"],
|
||||||
|
# window="hamming_window",
|
||||||
|
# mel_fmin=0.0,
|
||||||
|
# mel_fmax=None,
|
||||||
|
# use_htk=True,
|
||||||
|
# do_amp_to_db=False,
|
||||||
|
# n_mels=audio_config["num_mels"],
|
||||||
|
# power=2.0,
|
||||||
|
# use_mel=True,
|
||||||
|
# mel_norm=None,
|
||||||
|
# )
|
||||||
|
torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=audio_config["sample_rate"],
|
||||||
|
n_fft=audio_config["fft_size"],
|
||||||
|
win_length=audio_config["win_length"],
|
||||||
|
hop_length=audio_config["hop_length"],
|
||||||
|
window_fn=torch.hamming_window,
|
||||||
|
n_mels=audio_config["num_mels"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.torch_spec = None
|
||||||
|
|
||||||
outmap_size = int(self.input_dim / 8)
|
outmap_size = int(self.input_dim / 8)
|
||||||
|
|
||||||
self.attention = nn.Sequential(
|
self.attention = nn.Sequential(
|
||||||
|
@ -140,9 +191,23 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def forward(self, x, l2_norm=False):
|
def forward(self, x, l2_norm=False):
|
||||||
x = x.transpose(1, 2)
|
"""Forward pass of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
|
||||||
|
to compute the spectrogram on-the-fly.
|
||||||
|
l2_norm (bool): Whether to L2-normalize the outputs.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
|
||||||
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
x.squeeze_(1)
|
||||||
|
# if you torch spec compute it otherwise use the mel spec computed by the AP
|
||||||
|
if self.use_torch_spec:
|
||||||
|
x = self.torch_spec(x)
|
||||||
|
|
||||||
if self.log_input:
|
if self.log_input:
|
||||||
x = (x + 1e-6).log()
|
x = (x + 1e-6).log()
|
||||||
x = self.instancenorm(x).unsqueeze(1)
|
x = self.instancenorm(x).unsqueeze(1)
|
||||||
|
@ -175,11 +240,19 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
|
def inference(self, x, l2_norm=False):
|
||||||
|
return self.forward(x, l2_norm)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
|
||||||
"""
|
"""
|
||||||
Generate embeddings for a batch of utterances
|
Generate embeddings for a batch of utterances
|
||||||
x: 1xTxD
|
x: 1xTxD
|
||||||
"""
|
"""
|
||||||
|
# map to the waveform size
|
||||||
|
if self.use_torch_spec:
|
||||||
|
num_frames = num_frames * self.audio_config["hop_length"]
|
||||||
|
|
||||||
max_len = x.shape[1]
|
max_len = x.shape[1]
|
||||||
|
|
||||||
if max_len < num_frames:
|
if max_len < num_frames:
|
||||||
|
@ -195,11 +268,10 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
frames_batch.append(frames)
|
frames_batch.append(frames)
|
||||||
|
|
||||||
frames_batch = torch.cat(frames_batch, dim=0)
|
frames_batch = torch.cat(frames_batch, dim=0)
|
||||||
embeddings = self.forward(frames_batch, l2_norm=True)
|
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
|
||||||
|
|
||||||
if return_mean:
|
if return_mean:
|
||||||
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||||
|
|
|
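A minimal usage sketch for the updated `compute_embedding` signature above. The `audio_config` values and shapes are placeholders, and the call assumes `use_torch_spec=True` so the encoder consumes a raw waveform directly:

```python
import torch

# Hypothetical audio settings; only the keys referenced in the diff above are filled in.
audio_config = {
    "fft_size": 512,
    "win_length": 400,
    "hop_length": 160,
    "sample_rate": 16000,
    "preemphasis": 0.97,
    "num_mels": 64,
}

encoder = ResNetSpeakerEncoder(
    input_dim=64,
    proj_dim=512,
    log_input=True,
    use_torch_spec=True,
    audio_config=audio_config,
)
encoder.eval()

# 1 x T waveform; the mel spectrogram is computed inside the model.
wav = torch.randn(1, 16000 * 3)
embedding = encoder.compute_embedding(wav, num_frames=250, num_eval=10, return_mean=True, l2_norm=True)
print(embedding.shape)  # should be roughly (1, proj_dim)
```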
@@ -170,16 +170,24 @@ def to_camel(text):
     return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


-def setup_model(c):
-    if c.model_params["model_name"].lower() == "lstm":
+def setup_speaker_encoder_model(config: "Coqpit"):
+    if config.model_params["model_name"].lower() == "lstm":
         model = LSTMSpeakerEncoder(
-            c.model_params["input_dim"],
-            c.model_params["proj_dim"],
-            c.model_params["lstm_dim"],
-            c.model_params["num_lstm_layers"],
+            config.model_params["input_dim"],
+            config.model_params["proj_dim"],
+            config.model_params["lstm_dim"],
+            config.model_params["num_lstm_layers"],
+            use_torch_spec=config.model_params.get("use_torch_spec", False),
+            audio_config=config.audio,
+        )
+    elif config.model_params["model_name"].lower() == "resnet":
+        model = ResNetSpeakerEncoder(
+            input_dim=config.model_params["input_dim"],
+            proj_dim=config.model_params["proj_dim"],
+            log_input=config.model_params.get("log_input", False),
+            use_torch_spec=config.model_params.get("use_torch_spec", False),
+            audio_config=config.audio,
         )
-    elif c.model_params["model_name"].lower() == "resnet":
-        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
     return model
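A hedged sketch of driving the renamed factory; `load_config` and the JSON path are placeholders for however the training script builds its Coqpit config, which is assumed to expose `model_params` and `audio` as read above:

```python
from TTS.config import load_config  # assumed helper for loading a Coqpit config

config = load_config("speaker_encoder_config.json")  # placeholder path
config.model_params["model_name"] = "resnet"         # "lstm" or "resnet"
config.model_params["use_torch_spec"] = True         # compute spectrograms inside the model

model = setup_speaker_encoder_model(config)
```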
@@ -202,7 +202,7 @@ class Trainer:
         os.makedirs(output_path, exist_ok=True)

         # copy training assets to the output folder
-        copy_model_files(config, output_path, new_fields=None)
+        copy_model_files(config, output_path)

         # init class members
         self.args = args

@@ -439,7 +439,7 @@ class Trainer:
             if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
                 print(" > Restoring Scaler...")
                 scaler = _restore_list_objs(checkpoint["scaler"], scaler)
-        except (KeyError, RuntimeError):
+        except (KeyError, RuntimeError, ValueError):
             print(" > Partial model initialization...")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint["model"], config)
@@ -82,8 +82,14 @@ class VitsConfig(BaseTTSConfig):
         add_blank (bool):
             If true, a blank token is added in between every character. Defaults to `True`.

-        test_sentences (List[str]):
-            List of sentences to be used for testing.
+        test_sentences (List[List]):
+            List of sentences with speaker and language information to be used for testing.
+
+        language_ids_file (str):
+            Path to the language ids file.
+
+        use_language_embedding (bool):
+            If true, language embedding is used. Defaults to `False`.

     Note:
         Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

@@ -117,6 +123,7 @@ class VitsConfig(BaseTTSConfig):
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
+    speaker_encoder_loss_alpha: float = 1.0

     # data loader params
     return_wav: bool = True

@@ -130,13 +137,13 @@ class VitsConfig(BaseTTSConfig):
     add_blank: bool = True

     # testing
-    test_sentences: List[str] = field(
+    test_sentences: List[List] = field(
         default_factory=lambda: [
-            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
-            "Be a voice, not an echo.",
-            "I'm sorry Dave. I'm afraid I can't do that.",
-            "This cake is great. It's so delicious and moist.",
-            "Prior to November 22, 1963.",
+            ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
+            ["Be a voice, not an echo."],
+            ["I'm sorry Dave. I'm afraid I can't do that."],
+            ["This cake is great. It's so delicious and moist."],
+            ["Prior to November 22, 1963."],
         ]
     )

@@ -146,29 +153,15 @@ class VitsConfig(BaseTTSConfig):
     use_speaker_embedding: bool = False
     speakers_file: str = None
     speaker_embedding_channels: int = 256
+    language_ids_file: str = None
+    use_language_embedding: bool = False

     # use d-vectors
     use_d_vector_file: bool = False
-    d_vector_file: str = False
+    d_vector_file: str = None
     d_vector_dim: int = None

     def __post_init__(self):
-        # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
-        if self.num_speakers > 0:
-            self.model_args.num_speakers = self.num_speakers
+        for key, val in self.model_args.items():
+            if hasattr(self, key):
+                self[key] = val

-        # speaker embedding settings
-        if self.use_speaker_embedding:
-            self.model_args.use_speaker_embedding = True
-        if self.speakers_file:
-            self.model_args.speakers_file = self.speakers_file
-        if self.speaker_embedding_channels:
-            self.model_args.speaker_embedding_channels = self.speaker_embedding_channels
-
-        # d-vector settings
-        if self.use_d_vector_file:
-            self.model_args.use_d_vector_file = True
-        if self.d_vector_dim is not None and self.d_vector_dim > 0:
-            self.model_args.d_vector_dim = self.d_vector_dim
-        if self.d_vector_file:
-            self.model_args.d_vector_file = self.d_vector_file
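The new `test_sentences` entries are lists so that speaker, style-wav, and language information can ride along with the text. A hedged example of what a multilingual, multi-speaker config might carry; the import path, speaker names, and language codes are placeholders:

```python
from TTS.tts.configs.vits_config import VitsConfig  # assumed import path

config = VitsConfig()
config.test_sentences = [
    ["It took me quite a long time to develop a voice."],            # text only
    ["Be a voice, not an echo.", "VCTK_p225"],                       # text + speaker
    ["Bonjour, comment allez-vous ?", "VCTK_p234", None, "fr-fr"],   # text + speaker + style wav + language
]
```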
@@ -67,16 +67,22 @@ def load_tts_samples(
         root_path = dataset["path"]
         meta_file_train = dataset["meta_file_train"]
         meta_file_val = dataset["meta_file_val"]
+        ignored_speakers = dataset["ignored_speakers"]
+        language = dataset["language"]

         # setup the right data processor
         if formatter is None:
             formatter = _get_formatter_by_name(name)
         # load train set
-        meta_data_train = formatter(root_path, meta_file_train)
+        meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
+        meta_data_train = [[*item, language] for item in meta_data_train]

         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
-                meta_data_eval = formatter(root_path, meta_file_val)
+                meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
+                meta_data_eval = [[*item, language] for item in meta_data_eval]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
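With this change each entry in the dataset list is expected to carry `ignored_speakers` and `language` fields alongside the usual paths. A sketch of such an entry, written as a plain dict whose keys mirror what `load_tts_samples` reads above; the formatter name, paths, and speaker ids are placeholders:

```python
dataset_config = {
    "name": "vctk",                      # formatter name resolved by _get_formatter_by_name
    "path": "/data/VCTK-Corpus/",
    "meta_file_train": None,
    "meta_file_val": None,
    "language": "en",
    "ignored_speakers": ["p225", "p226"],  # excluded by the formatter
}
```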
@@ -37,6 +37,7 @@ class TTSDataset(Dataset):
         enable_eos_bos: bool = False,
         speaker_id_mapping: Dict = None,
         d_vector_mapping: Dict = None,
+        language_id_mapping: Dict = None,
         use_noise_augment: bool = False,
         verbose: bool = False,
     ):

@@ -122,7 +123,9 @@ class TTSDataset(Dataset):
         self.enable_eos_bos = enable_eos_bos
         self.speaker_id_mapping = speaker_id_mapping
         self.d_vector_mapping = d_vector_mapping
+        self.language_id_mapping = language_id_mapping
         self.use_noise_augment = use_noise_augment

         self.verbose = verbose
         self.input_seq_computed = False
         self.rescue_item_idx = 1

@@ -197,10 +200,10 @@ class TTSDataset(Dataset):
     def load_data(self, idx):
         item = self.items[idx]

-        if len(item) == 4:
-            text, wav_file, speaker_name, attn_file = item
+        if len(item) == 5:
+            text, wav_file, speaker_name, language_name, attn_file = item
         else:
-            text, wav_file, speaker_name = item
+            text, wav_file, speaker_name, language_name = item
             attn = None
         raw_text = text

@@ -218,7 +221,7 @@ class TTSDataset(Dataset):
                 self.phoneme_cache_path,
                 self.enable_eos_bos,
                 self.cleaners,
-                self.phoneme_language,
+                language_name if language_name else self.phoneme_language,
                 self.custom_symbols,
                 self.characters,
                 self.add_blank,

@@ -260,6 +263,7 @@ class TTSDataset(Dataset):
             "attn": attn,
             "item_idx": self.items[idx][1],
             "speaker_name": speaker_name,
+            "language_name": language_name,
             "wav_file_name": os.path.basename(wav_file),
         }
         return sample

@@ -269,6 +273,9 @@ class TTSDataset(Dataset):
         item = args[0]
         func_args = args[1]
         text, wav_file, *_ = item
+        func_args[3] = (
+            item[3] if item[3] else func_args[3]
+        )  # override phoneme language if specified by the dataset formatter
         phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
         return phonemes

@@ -335,7 +342,6 @@ class TTSDataset(Dataset):
         else:
             lengths = np.array([len(ins[0]) for ins in self.items])

-        # sort items based on the sequence length in ascending order
         idxs = np.argsort(lengths)
         new_items = []
         ignored = []

@@ -345,10 +351,7 @@ class TTSDataset(Dataset):
                 ignored.append(idx)
             else:
                 new_items.append(self.items[idx])

         # shuffle batch groups
-        # create batches with similar length items
-        # the larger the `batch_group_size`, the higher the length variety in a batch.
         if self.batch_group_size > 0:
             for i in range(len(new_items) // self.batch_group_size):
                 offset = i * self.batch_group_size

@@ -356,14 +359,8 @@ class TTSDataset(Dataset):
                 temp_items = new_items[offset:end_offset]
                 random.shuffle(temp_items)
                 new_items[offset:end_offset] = temp_items

-        if len(new_items) == 0:
-            raise RuntimeError(" [!] No items left after filtering.")
-
-        # update items to the new sorted items
         self.items = new_items

-        # logging
         if self.verbose:
             print(" | > Max length sequence: {}".format(np.max(lengths)))
             print(" | > Min length sequence: {}".format(np.min(lengths)))

@@ -413,9 +410,14 @@ class TTSDataset(Dataset):
             # convert list of dicts to dict of lists
             batch = {k: [dic[k] for dic in batch] for k in batch[0]}

+            # get language ids from language names
+            if self.language_id_mapping is not None:
+                language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
+            else:
+                language_ids = None
             # get pre-computed d-vectors
             if self.d_vector_mapping is not None:
-                wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
+                wav_files_names = list(batch["wav_file_name"])
                 d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
                 d_vectors = None

@@ -466,6 +468,9 @@ class TTSDataset(Dataset):
             if speaker_ids is not None:
                 speaker_ids = torch.LongTensor(speaker_ids)

+            if language_ids is not None:
+                language_ids = torch.LongTensor(language_ids)
+
             # compute linear spectrogram
             if self.compute_linear_spec:
                 linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]

@@ -528,6 +533,7 @@ class TTSDataset(Dataset):
                 "waveform": wav_padded,
                 "raw_text": batch["raw_text"],
                 "pitch": pitch,
+                "language_ids": language_ids,
             }

         raise TypeError(

@@ -542,7 +548,6 @@ class TTSDataset(Dataset):

 class PitchExtractor:
     """Pitch Extractor for computing F0 from wav files.

     Args:
         items (List[List]): Dataset samples.
         verbose (bool): Whether to print the progress.
@@ -12,7 +12,7 @@ from tqdm import tqdm
 ########################


-def tweb(root_path, meta_file):
+def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalize TWEB dataset.
     https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset
     """

@@ -28,7 +28,7 @@ def tweb(root_path, meta_file):
     return items


-def mozilla(root_path, meta_file):
+def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalizes Mozilla meta data files to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
     items = []

@@ -43,7 +43,7 @@ def mozilla(root_path, meta_file):
     return items


-def mozilla_de(root_path, meta_file):
+def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalizes Mozilla meta data files to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
     items = []

@@ -59,7 +59,7 @@ def mozilla_de(root_path, meta_file):
     return items


-def mailabs(root_path, meta_files=None):
+def mailabs(root_path, meta_files=None, ignored_speakers=None):
     """Normalizes M-AI-Labs meta data files to TTS format

     Args:

@@ -68,25 +68,34 @@ def mailabs(root_path, meta_files=None):
             recursively. Defaults to None
     """
     speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
-    if meta_files is None:
+    if not meta_files:
         csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
     else:
         csv_files = meta_files

     # meta_files = [f.strip() for f in meta_files.split(",")]
     items = []
     for csv_file in csv_files:
-        txt_file = os.path.join(root_path, csv_file)
+        if os.path.isfile(csv_file):
+            txt_file = csv_file
+        else:
+            txt_file = os.path.join(root_path, csv_file)
+
         folder = os.path.dirname(txt_file)
         # determine speaker based on folder structure...
         speaker_name_match = speaker_regex.search(txt_file)
         if speaker_name_match is None:
             continue
         speaker_name = speaker_name_match.group("speaker_name")
+        # ignore speakers
+        if isinstance(ignored_speakers, list):
+            if speaker_name in ignored_speakers:
+                continue
         print(" | > {}".format(csv_file))
         with open(txt_file, "r", encoding="utf-8") as ttf:
             for line in ttf:
                 cols = line.split("|")
-                if meta_files is None:
+                if not meta_files:
                     wav_file = os.path.join(folder, "wavs", cols[0] + ".wav")
                 else:
                     wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")

@@ -94,11 +103,12 @@ def mailabs(root_path, meta_files=None):
                     text = cols[1].strip()
                     items.append([text, wav_file, speaker_name])
                 else:
-                    raise RuntimeError("> File %s does not exist!" % (wav_file))
+                    # M-AI-Labs have some missing samples, so just print the warning
+                    print("> File %s does not exist!" % (wav_file))
     return items


-def ljspeech(root_path, meta_file):
+def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalizes the LJSpeech meta data file to TTS format
     https://keithito.com/LJ-Speech-Dataset/"""
     txt_file = os.path.join(root_path, meta_file)

@@ -113,7 +123,7 @@ def ljspeech(root_path, meta_file):
     return items


-def ljspeech_test(root_path, meta_file):
+def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalizes the LJSpeech meta data file for TTS testing
     https://keithito.com/LJ-Speech-Dataset/"""
     txt_file = os.path.join(root_path, meta_file)

@@ -127,7 +137,7 @@ def ljspeech_test(root_path, meta_file):
     return items


-def sam_accenture(root_path, meta_file):
+def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalizes the sam-accenture meta data file to TTS format
     https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
     xml_file = os.path.join(root_path, "voice_over_recordings", meta_file)

@@ -144,12 +154,12 @@ def sam_accenture(root_path, meta_file):
     return items


-def ruslan(root_path, meta_file):
+def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalizes the RUSLAN meta data file to TTS format
     https://ruslan-corpus.github.io/"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
-    speaker_name = "ljspeech"
+    speaker_name = "ruslan"
     with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")

@@ -159,11 +169,11 @@ def ruslan(root_path, meta_file):
     return items


-def css10(root_path, meta_file):
+def css10(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalizes the CSS10 dataset file to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
-    speaker_name = "ljspeech"
+    speaker_name = "css10"
     with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")

@@ -173,7 +183,7 @@ def css10(root_path, meta_file):
     return items


-def nancy(root_path, meta_file):
+def nancy(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Normalizes the Nancy meta data file to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
     items = []

@@ -187,7 +197,7 @@ def nancy(root_path, meta_file):
     return items


-def common_voice(root_path, meta_file):
+def common_voice(root_path, meta_file, ignored_speakers=None):
     """Normalize the common voice meta data file to TTS format."""
     txt_file = os.path.join(root_path, meta_file)
     items = []

@@ -198,15 +208,19 @@ def common_voice(root_path, meta_file):
             cols = line.split("\t")
             text = cols[2]
             speaker_name = cols[0]
+            # ignore speakers
+            if isinstance(ignored_speakers, list):
+                if speaker_name in ignored_speakers:
+                    continue
             wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
             items.append([text, wav_file, "MCV_" + speaker_name])
     return items


-def libri_tts(root_path, meta_files=None):
+def libri_tts(root_path, meta_files=None, ignored_speakers=None):
     """https://ai.google/tools/datasets/libri-tts/"""
     items = []
-    if meta_files is None:
+    if not meta_files:
         meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
     else:
         if isinstance(meta_files, str):

@@ -222,13 +236,17 @@ def libri_tts(root_path, meta_files=None):
             _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
             wav_file = os.path.join(_root_path, file_name + ".wav")
             text = cols[2]
+            # ignore speakers
+            if isinstance(ignored_speakers, list):
+                if speaker_name in ignored_speakers:
+                    continue
             items.append([text, wav_file, "LTTS_" + speaker_name])
     for item in items:
         assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
     return items


-def custom_turkish(root_path, meta_file):
+def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "turkish-female"

@@ -247,7 +265,7 @@ def custom_turkish(root_path, meta_file):


 # ToDo: add the dataset link when the dataset is released publicly
-def brspeech(root_path, meta_file):
+def brspeech(root_path, meta_file, ignored_speakers=None):
     """BRSpeech 3.0 beta"""
     txt_file = os.path.join(root_path, meta_file)
     items = []

@@ -258,21 +276,25 @@ def brspeech(root_path, meta_file):
             cols = line.split("|")
             wav_file = os.path.join(root_path, cols[0])
             text = cols[2]
-            speaker_name = cols[3]
-            items.append([text, wav_file, speaker_name])
+            speaker_id = cols[3]
+            # ignore speakers
+            if isinstance(ignored_speakers, list):
+                if speaker_id in ignored_speakers:
+                    continue
+            items.append([text, wav_file, speaker_id])
     return items


-def vctk(root_path, meta_files=None, wavs_path="wav48"):
+def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
     """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
-    test_speakers = meta_files
     items = []
     meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
     for meta_file in meta_files:
         _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
         file_id = txt_file.split(".")[0]
-        if isinstance(test_speakers, list):  # if is list ignore this speakers ids
-            if speaker_id in test_speakers:
+        # ignore speakers
+        if isinstance(ignored_speakers, list):
+            if speaker_id in ignored_speakers:
                 continue
         with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]

@@ -282,15 +304,16 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
     return items


-def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
+def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):  # pylint: disable=unused-argument
     """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
     items = []
     txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
     for text_file in txt_files:
         _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
         file_id = txt_file.split(".")[0]
-        if isinstance(meta_files, list):  # if is list ignore this speakers ids
-            if speaker_id in meta_files:
+        # ignore speakers
+        if isinstance(ignored_speakers, list):
+            if speaker_id in ignored_speakers:
                 continue
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
         items.append([None, wav_file, "VCTK_" + speaker_id])

@@ -298,7 +321,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
     return items


-def mls(root_path, meta_files=None):
+def mls(root_path, meta_files=None, ignored_speakers=None):
     """http://www.openslr.org/94/"""
     items = []
     with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:

@@ -307,19 +330,23 @@ def mls(root_path, meta_files=None):
             text = text[:-1]
             speaker, book, *_ = file.split("_")
             wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav")
+            # ignore speakers
+            if isinstance(ignored_speakers, list):
+                if speaker in ignored_speakers:
+                    continue
             items.append([text, wav_file, "MLS_" + speaker])
     return items


 # ======================================== VOX CELEB ===========================================
-def voxceleb2(root_path, meta_file=None):
+def voxceleb2(root_path, meta_file=None, **kwargs):  # pylint: disable=unused-argument
     """
     :param meta_file Used only for consistency with load_tts_samples api
     """
     return _voxcel_x(root_path, meta_file, voxcel_idx="2")


-def voxceleb1(root_path, meta_file=None):
+def voxceleb1(root_path, meta_file=None, **kwargs):  # pylint: disable=unused-argument
     """
     :param meta_file Used only for consistency with load_tts_samples api
     """

@@ -361,7 +388,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
     return [x.strip().split("|") for x in f.readlines()]


-def baker(root_path: str, meta_file: str) -> List[List[str]]:
+def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]:  # pylint: disable=unused-argument
     """Normalizes the Baker meta data file to TTS format

     Args:

@@ -381,7 +408,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
     return items


-def kokoro(root_path, meta_file):
+def kokoro(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
     """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
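Every formatter now accepts (and may ignore) `ignored_speakers` through its signature or `**kwargs`, and returns `[text, wav_file, speaker_name]` items that `load_tts_samples` later extends with the language tag. A minimal sketch of a custom formatter written against that contract; the function name, file layout, and column order are assumptions:

```python
import os


def my_corpus(root_path, meta_file, ignored_speakers=None, **kwargs):  # hypothetical formatter
    """Assumed metadata layout: `wav_id|speaker|transcription` per line."""
    items = []
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
        for line in ttf:
            wav_id, speaker_name, text = line.strip().split("|")
            # skip speakers listed in the dataset config
            if isinstance(ignored_speakers, list) and speaker_name in ignored_speakers:
                continue
            wav_file = os.path.join(root_path, "wavs", wav_id + ".wav")
            items.append([text, wav_file, speaker_name])
    return items
```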
@@ -18,8 +18,13 @@ class DurationPredictor(nn.Module):
         dropout_p (float): Dropout rate used after each conv layer.
     """

-    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
+    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
         super().__init__()

+        # add language embedding dim in the input
+        if language_emb_dim:
+            in_channels += language_emb_dim
+
         # class arguments
         self.in_channels = in_channels
         self.filter_channels = hidden_channels

@@ -36,7 +41,10 @@ class DurationPredictor(nn.Module):
         if cond_channels is not None and cond_channels != 0:
             self.cond = nn.Conv1d(cond_channels, in_channels, 1)

-    def forward(self, x, x_mask, g=None):
+        if language_emb_dim != 0 and language_emb_dim is not None:
+            self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)
+
+    def forward(self, x, x_mask, g=None, lang_emb=None):
         """
         Shapes:
             - x: :math:`[B, C, T]`

@@ -45,6 +53,10 @@ class DurationPredictor(nn.Module):
         """
         if g is not None:
             x = x + self.cond(g)
+
+        if lang_emb is not None:
+            x = x + self.cond_lang(lang_emb)
+
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.norm_1(x)
@@ -532,6 +532,7 @@ class VitsGeneratorLoss(nn.Module):
         self.feat_loss_alpha = c.feat_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
+        self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
         self.stft = TorchSTFT(
             c.audio.fft_size,
             c.audio.hop_length,

@@ -585,6 +586,11 @@ class VitsGeneratorLoss(nn.Module):
         l = kl / torch.sum(z_mask)
         return l

+    @staticmethod
+    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
+        l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
+        return l
+
     def forward(
         self,
         waveform,

@@ -598,6 +604,9 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_fake,
         feats_disc_real,
         loss_duration,
+        use_speaker_encoder_as_loss=False,
+        gt_spk_emb=None,
+        syn_spk_emb=None,
     ):
         """
         Shapes:

@@ -618,13 +627,20 @@ class VitsGeneratorLoss(nn.Module):
         # compute mel spectrograms from the waveforms
         mel = self.stft(waveform)
         mel_hat = self.stft(waveform_hat)

         # compute losses
+        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
         loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
         loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
-        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
         loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration

+        if use_speaker_encoder_as_loss:
+            loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
+            loss += loss_se
+            return_dict["loss_spk_encoder"] = loss_se
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
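In isolation, the speaker-consistency term added above is just a negated cosine similarity between the reference and synthesized speaker embeddings. A small self-contained check of that behaviour; the batch size and embedding width are illustrative only:

```python
import torch
import torch.nn.functional as F


def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
    # identical embeddings give -1, uncorrelated embeddings sit near 0
    return -F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()


gt = torch.randn(4, 512)  # embeddings from the frozen speaker encoder
print(cosine_similarity_loss(gt, gt))                    # tensor(-1.)
print(cosine_similarity_loss(gt, torch.randn(4, 512)))   # close to 0 for random vectors
```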
@@ -37,6 +37,7 @@ class TextEncoder(nn.Module):
         num_layers: int,
         kernel_size: int,
         dropout_p: float,
+        language_emb_dim: int = None,
     ):
         """Text Encoder for VITS model.

@@ -55,8 +56,12 @@ class TextEncoder(nn.Module):
         self.hidden_channels = hidden_channels

         self.emb = nn.Embedding(n_vocab, hidden_channels)

         nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

+        if language_emb_dim:
+            hidden_channels += language_emb_dim
+
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,

@@ -72,13 +77,18 @@ class TextEncoder(nn.Module):

         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, x, x_lengths):
+    def forward(self, x, x_lengths, lang_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
         """
         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]

+        # concat the lang emb in embedding chars
+        if lang_emb is not None:
+            x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
+
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
@@ -178,10 +178,21 @@ class StochasticDurationPredictor(nn.Module):
     """

     def __init__(
-        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dropout_p: float,
+        num_flows=4,
+        cond_channels=0,
+        language_emb_dim=0,
     ):
         super().__init__()

+        # add language embedding dim in the input
+        if language_emb_dim:
+            in_channels += language_emb_dim
+
         # condition encoder text
         self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
         self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)

@@ -205,7 +216,10 @@ class StochasticDurationPredictor(nn.Module):
         if cond_channels != 0 and cond_channels is not None:
             self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)

-    def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
+        if language_emb_dim != 0 and language_emb_dim is not None:
+            self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1)
+
+    def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0):
         """
         Shapes:
             - x: :math:`[B, C, T]`

@@ -217,6 +231,10 @@ class StochasticDurationPredictor(nn.Module):
         x = self.pre(x)
         if g is not None:
             x = x + self.cond(g)
+
+        if lang_emb is not None:
+            x = x + self.cond_lang(lang_emb)
+
         x = self.convs(x, x_mask)
         x = self.proj(x) * x_mask
@@ -2,7 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
 from TTS.utils.generic_utils import find_module


-def setup_model(config, speaker_manager: "SpeakerManager" = None):
+def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
     print(" > Using model: {}".format(config.model))
     # fetch the right model implementation.
     if "base_model" in config and config["base_model"] is not None:

@@ -31,7 +31,10 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None):
         config.model_params.num_chars = num_chars
     if "model_args" in config:
         config.model_args.num_chars = num_chars
-    model = MyModel(config, speaker_manager=speaker_manager)
+    if config.model.lower() in ["vits"]:  # If model supports multiple languages
+        model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
+    else:
+        model = MyModel(config, speaker_manager=speaker_manager)
     return model
@@ -12,7 +12,8 @@ from torch.utils.data.distributed import DistributedSampler
 from TTS.model import BaseModel
 from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
+from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text import make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

@@ -73,9 +74,18 @@ class BaseTTS(BaseModel):
     def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
         return get_speaker_manager(config, restore_path, data, out_path)

-    def init_multispeaker(self, config: Coqpit):
-        """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
-        vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
+    def init_multispeaker(self, config: Coqpit, data: List = None):
+        """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
+        `in_channels` size of the connected layers.
+
+        This implementation yields 3 possible outcomes:
+
+        1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing.
+        2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
+        3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
+        `config.d_vector_dim` or 512.
+
+        You can override this function for new models.

         Args:
             config (Coqpit): Model configuration.

@@ -97,6 +107,57 @@ class BaseTTS(BaseModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)

+    def get_aux_input(self, **kwargs) -> Dict:
+        """Prepare and return `aux_input` used by `forward()`"""
+        return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
+
+    def get_aux_input_from_test_setences(self, sentence_info):
+        if hasattr(self.config, "model_args"):
+            config = self.config.model_args
+        else:
+            config = self.config
+
+        # extract speaker and language info
+        text, speaker_name, style_wav, language_name = None, None, None, None
+
+        if isinstance(sentence_info, list):
+            if len(sentence_info) == 1:
+                text = sentence_info[0]
+            elif len(sentence_info) == 2:
+                text, speaker_name = sentence_info
+            elif len(sentence_info) == 3:
+                text, speaker_name, style_wav = sentence_info
+            elif len(sentence_info) == 4:
+                text, speaker_name, style_wav, language_name = sentence_info
+        else:
+            text = sentence_info
+
+        # get speaker id/d_vector
+        speaker_id, d_vector, language_id = None, None, None
+        if hasattr(self, "speaker_manager"):
+            if config.use_d_vector_file:
+                if speaker_name is None:
+                    d_vector = self.speaker_manager.get_random_d_vector()
+                else:
+                    d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
+            elif config.use_speaker_embedding:
+                if speaker_name is None:
+                    speaker_id = self.speaker_manager.get_random_speaker_id()
+                else:
+                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]
+
+        # get language id
+        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+            language_id = self.language_manager.language_id_mapping[language_name]
+
+        return {
+            "text": text,
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "d_vector": d_vector,
+            "language_id": language_id,
+        }
+
     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
@@ -122,6 +183,7 @@ class BaseTTS(BaseModel):
         attn_mask = batch["attns"]
         waveform = batch["waveform"]
         pitch = batch["pitch"]
+        language_ids = batch["language_ids"]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())

@@ -169,6 +231,7 @@ class BaseTTS(BaseModel):
             "item_idx": item_idx,
             "waveform": waveform,
             "pitch": pitch,
+            "language_ids": language_ids,
         }

     def get_data_loader(

@@ -188,8 +251,15 @@ class BaseTTS(BaseModel):

             # setup multi-speaker attributes
             if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
-                speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
-                d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
+                if hasattr(config, "model_args"):
+                    speaker_id_mapping = (
+                        self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                    )
+                    d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
+                    config.use_d_vector_file = config.model_args.use_d_vector_file
+                else:
+                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
+                    d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
             else:
                 speaker_id_mapping = None
                 d_vector_mapping = None

@@ -199,7 +269,14 @@ class BaseTTS(BaseModel):
             if hasattr(self, "make_symbols"):
                 custom_symbols = self.make_symbols(self.config)

-            # init dataset
+            if hasattr(self, "language_manager"):
+                language_id_mapping = (
+                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+                )
+            else:
+                language_id_mapping = None
+
+            # init dataloader
             dataset = TTSDataset(
                 outputs_per_step=config.r if "r" in config else 1,
                 text_cleaner=config.text_cleaner,

@@ -222,7 +299,8 @@ class BaseTTS(BaseModel):
                 use_noise_augment=False if is_eval else config.use_noise_augment,
                 verbose=verbose,
                 speaker_id_mapping=speaker_id_mapping,
-                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
+                d_vector_mapping=d_vector_mapping,
+                language_id_mapping=language_id_mapping,
             )

             # pre-compute phonemes

@@ -268,7 +346,22 @@ class BaseTTS(BaseModel):
             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None

-            # init dataloader
+            # Weighted samplers
+            assert not (
+                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
+            ), "language_weighted_sampler is not supported with DistributedSampler"
+            assert not (
+                num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
+            ), "speaker_weighted_sampler is not supported with DistributedSampler"
+
+            if sampler is None:
+                if getattr(config, "use_language_weighted_sampler", False):
+                    print(" > Using Language weighted sampler")
+                    sampler = get_language_weighted_sampler(dataset.items)
+                elif getattr(config, "use_speaker_weighted_sampler", False):
+                    print(" > Using Language weighted sampler")
+                    sampler = get_speaker_weighted_sampler(dataset.items)
+
             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
@ -340,8 +433,7 @@ class BaseTTS(BaseModel):
|
||||||
return test_figures, test_audios
|
return test_figures, test_audios
|
||||||
|
|
||||||
def on_init_start(self, trainer):
|
def on_init_start(self, trainer):
|
||||||
"""Save the speaker.json at the beginning of the training. And update the config.json with the
|
"""Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
|
||||||
speakers.json file path."""
|
|
||||||
if self.speaker_manager is not None:
|
if self.speaker_manager is not None:
|
||||||
output_path = os.path.join(trainer.output_path, "speakers.json")
|
output_path = os.path.join(trainer.output_path, "speakers.json")
|
||||||
self.speaker_manager.save_speaker_ids_to_file(output_path)
|
self.speaker_manager.save_speaker_ids_to_file(output_path)
|
||||||
|
@ -352,3 +444,13 @@ class BaseTTS(BaseModel):
|
||||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||||
print(f" > `speakers.json` is saved to {output_path}.")
|
print(f" > `speakers.json` is saved to {output_path}.")
|
||||||
print(" > `speakers_file` is updated in the config.json.")
|
print(" > `speakers_file` is updated in the config.json.")
|
||||||
|
|
||||||
|
if hasattr(self, "language_manager") and self.language_manager is not None:
|
||||||
|
output_path = os.path.join(trainer.output_path, "language_ids.json")
|
||||||
|
self.language_manager.save_language_ids_to_file(output_path)
|
||||||
|
trainer.config.language_ids_file = output_path
|
||||||
|
if hasattr(trainer.config, "model_args"):
|
||||||
|
trainer.config.model_args.language_ids_file = output_path
|
||||||
|
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||||
|
print(f" > `language_ids.json` is saved to {output_path}.")
|
||||||
|
print(" > `language_ids_file` is updated in the config.json.")
|
||||||
|
|
|
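For orientation, here is a minimal sketch (not part of the commit) of how the language mapping that `on_init_start` writes to `language_ids.json` can be produced with the `LanguageManager` introduced later in this diff. The dataset names and language codes are purely illustrative.

```python
# Illustrative only: the dataset entries are made up; the LanguageManager API
# (config parsing, sorted id assignment, JSON export) is the one added in this commit.
from TTS.tts.utils.languages import LanguageManager


class ToyConfig:
    # hypothetical stand-in for a Coqpit config; only the `datasets` field is read here
    datasets = [
        {"name": "vctk", "language": "en"},
        {"name": "css10_de", "language": "de"},
    ]


manager = LanguageManager(config=ToyConfig())
print(manager.language_id_mapping)  # {'de': 0, 'en': 1} (ids follow sorted language names)
manager.save_language_ids_to_file("language_ids.json")
```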
@@ -1,13 +1,15 @@
import math
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple

import torch
import torchaudio
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F

from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.discriminator import VitsDiscriminator

@@ -15,6 +17,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment

@@ -138,11 +141,50 @@ class VitsArgs(Coqpit):
        use_d_vector_file (bool):
            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.

        d_vector_file (str):
            Path to the file including pre-computed speaker embeddings. Defaults to None.

        d_vector_dim (int):
            Number of d-vector channels. Defaults to 0.

        detach_dp_input (bool):
            Detach duration predictor's input from the network for stopping the gradients. Defaults to True.

        use_language_embedding (bool):
            Enable/Disable language embedding for multilingual models. Defaults to False.

        embedded_language_dim (int):
            Number of language embedding channels. Defaults to 4.

        num_languages (int):
            Number of languages for the language embedding layer. Defaults to 0.

        language_ids_file (str):
            Path to the language mapping file for the Language Manager. Defaults to None.

        use_speaker_encoder_as_loss (bool):
            Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.

        speaker_encoder_config_path (str):
            Path to the speaker encoder config file, used for SCL. Defaults to "".

        speaker_encoder_model_path (str):
            Path to the speaker encoder checkpoint file, used for SCL. Defaults to "".

        freeze_encoder (bool):
            Freeze the encoder weights during training. Defaults to False.

        freeze_DP (bool):
            Freeze the duration predictor weights during training. Defaults to False.

        freeze_PE (bool):
            Freeze the posterior encoder weights during training. Defaults to False.

        freeze_flow_decoder (bool):
            Freeze the flow decoder weights during training. Defaults to False.

        freeze_waveform_decoder (bool):
            Freeze the waveform decoder weights during training. Defaults to False.
    """

    num_chars: int = 100

@@ -179,11 +221,23 @@ class VitsArgs(Coqpit):
    use_speaker_embedding: bool = False
    num_speakers: int = 0
    speakers_file: str = None
    d_vector_file: str = None
    speaker_embedding_channels: int = 256
    use_d_vector_file: bool = False
    d_vector_dim: int = 0
    detach_dp_input: bool = True
    use_language_embedding: bool = False
    embedded_language_dim: int = 4
    num_languages: int = 0
    language_ids_file: str = None
    use_speaker_encoder_as_loss: bool = False
    speaker_encoder_config_path: str = ""
    speaker_encoder_model_path: str = ""
    freeze_encoder: bool = False
    freeze_DP: bool = False
    freeze_PE: bool = False
    freeze_flow_decoder: bool = False
    freeze_waveform_decoder: bool = False
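To make the new switches concrete, here is a small illustrative sketch (not part of the commit) of a `VitsArgs` instance with the multilingual fields enabled; the numbers and the `language_ids.json` path are placeholders, not recommended settings.

```python
# Placeholder values only; VitsArgs is the Coqpit dataclass defined above.
from TTS.tts.models.vits import VitsArgs

args = VitsArgs(
    use_speaker_embedding=True,
    num_speakers=4,
    use_language_embedding=True,
    embedded_language_dim=4,
    num_languages=2,
    language_ids_file="language_ids.json",  # hypothetical path produced by LanguageManager
    use_speaker_encoder_as_loss=False,
)
print(args.use_language_embedding, args.num_languages)
```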
class Vits(BaseTTS):

@@ -216,13 +270,18 @@ class Vits(BaseTTS):

    # pylint: disable=dangerous-default-value

    def __init__(
        self,
        config: Coqpit,
        speaker_manager: SpeakerManager = None,
        language_manager: LanguageManager = None,
    ):

        super().__init__(config)

        self.END2END = True

        self.speaker_manager = speaker_manager
        self.language_manager = language_manager
        if config.__class__.__name__ == "VitsConfig":
            # loading from VitsConfig
            if "num_chars" not in config:

@@ -242,6 +301,7 @@ class Vits(BaseTTS):
        self.args = args

        self.init_multispeaker(config)
        self.init_multilingual(config)

        self.length_scale = args.length_scale
        self.noise_scale = args.noise_scale

@@ -260,6 +320,7 @@ class Vits(BaseTTS):
            args.num_layers_text_encoder,
            args.kernel_size_text_encoder,
            args.dropout_p_text_encoder,
            language_emb_dim=self.embedded_language_dim,
        )

        self.posterior_encoder = PosteriorEncoder(

@@ -289,10 +350,16 @@ class Vits(BaseTTS):
                args.dropout_p_duration_predictor,
                4,
                cond_channels=self.embedded_speaker_dim,
                language_emb_dim=self.embedded_language_dim,
            )
        else:
            self.duration_predictor = DurationPredictor(
                args.hidden_channels,
                256,
                3,
                args.dropout_p_duration_predictor,
                cond_channels=self.embedded_speaker_dim,
                language_emb_dim=self.embedded_language_dim,
            )

        self.waveform_decoder = HifiganGenerator(

@@ -318,54 +385,150 @@ class Vits(BaseTTS):
        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
        or with external `d_vectors` computed from a speaker encoder model.

        You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.

        Args:
            config (Coqpit): Model configuration.
            data (List, optional): Dataset items to infer number of speakers. Defaults to None.
        """
        self.embedded_speaker_dim = 0
        self.num_speakers = self.args.num_speakers

        if self.speaker_manager:
            self.num_speakers = self.speaker_manager.num_speakers

        if self.args.use_speaker_embedding:
            self._init_speaker_embedding()

        if self.args.use_d_vector_file:
            self._init_d_vector()

        # TODO: make this a function
        if self.args.use_speaker_encoder_as_loss:
            if self.speaker_manager.speaker_encoder is None and (
                not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
            ):
                raise RuntimeError(
                    " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                )

            self.speaker_manager.speaker_encoder.eval()
            print(" > External Speaker Encoder Loaded !!")

            if (
                hasattr(self.speaker_manager.speaker_encoder, "audio_config")
                and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
            ):
                self.audio_transform = torchaudio.transforms.Resample(
                    orig_freq=self.audio_config["sample_rate"],
                    new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                )
            else:
                self.audio_transform = None

    def _init_speaker_embedding(self):
        # pylint: disable=attribute-defined-outside-init
        if self.num_speakers > 0:
            print(" > initialization of speaker-embedding layers.")
            self.embedded_speaker_dim = self.args.speaker_embedding_channels
            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

    def _init_d_vector(self):
        # pylint: disable=attribute-defined-outside-init
        if hasattr(self, "emb_g"):
            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
        self.embedded_speaker_dim = self.args.d_vector_dim

    def init_multilingual(self, config: Coqpit):
        """Initialize multilingual modules of a model.

        Args:
            config (Coqpit): Model configuration.
        """
        if self.args.language_ids_file is not None:
            self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)

        if self.args.use_language_embedding and self.language_manager:
            print(" > initialization of language-embedding layers.")
            self.num_languages = self.language_manager.num_languages
            self.embedded_language_dim = self.args.embedded_language_dim
            self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
            torch.nn.init.xavier_uniform_(self.emb_l.weight)
        else:
            self.embedded_language_dim = 0
            self.emb_l = None

    @staticmethod
    def _set_cond_input(aux_input: Dict):
        """Set the speaker conditioning input based on the multi-speaker mode."""
        sid, g, lid = None, None, None
        if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
            sid = aux_input["speaker_ids"]
            if sid.ndim == 0:
                sid = sid.unsqueeze_(0)
        if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
            g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
            if g.ndim == 2:
                g = g.unsqueeze_(0)

        if "language_ids" in aux_input and aux_input["language_ids"] is not None:
            lid = aux_input["language_ids"]
            if lid.ndim == 0:
                lid = lid.unsqueeze_(0)

        return sid, g, lid

    def get_aux_input(self, aux_input: Dict):
        sid, g, lid = self._set_cond_input(aux_input)
        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}

    def get_aux_input_from_test_sentences(self, sentence_info):
        if hasattr(self.config, "model_args"):
            config = self.config.model_args
        else:
            config = self.config

        # extract speaker and language info
        text, speaker_name, style_wav, language_name = None, None, None, None

        if isinstance(sentence_info, list):
            if len(sentence_info) == 1:
                text = sentence_info[0]
            elif len(sentence_info) == 2:
                text, speaker_name = sentence_info
            elif len(sentence_info) == 3:
                text, speaker_name, style_wav = sentence_info
            elif len(sentence_info) == 4:
                text, speaker_name, style_wav, language_name = sentence_info
        else:
            text = sentence_info

        # get speaker id/d_vector
        speaker_id, d_vector, language_id = None, None, None
        if hasattr(self, "speaker_manager"):
            if config.use_d_vector_file:
                if speaker_name is None:
                    d_vector = self.speaker_manager.get_random_d_vector()
                else:
                    d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False)
            elif config.use_speaker_embedding:
                if speaker_name is None:
                    speaker_id = self.speaker_manager.get_random_speaker_id()
                else:
                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]

        # get language id
        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
            language_id = self.language_manager.language_id_mapping[language_name]

        return {
            "text": text,
            "speaker_id": speaker_id,
            "style_wav": style_wav,
            "d_vector": d_vector,
            "language_id": language_id,
            "language_name": language_name,
        }

    def forward(
        self,

@@ -373,7 +536,8 @@ class Vits(BaseTTS):
        x_lengths: torch.tensor,
        y: torch.tensor,
        y_lengths: torch.tensor,
        waveform: torch.tensor,
        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
    ) -> Dict:
        """Forward pass of the model.

@@ -382,7 +546,9 @@ class Vits(BaseTTS):
            x_lengths (torch.tensor): Batch of input character sequence lengths.
            y (torch.tensor): Batch of input spectrograms.
            y_lengths (torch.tensor): Batch of input spectrogram lengths.
            waveform (torch.tensor): Batch of ground truth waveforms per sample.
            aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
                Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.

        Returns:
            Dict: model outputs keyed by the output name.

@@ -392,17 +558,24 @@ class Vits(BaseTTS):
            - x_lengths: :math:`[B]`
            - y: :math:`[B, C, T_spec]`
            - y_lengths: :math:`[B]`
            - waveform: :math:`[B, T_wav, 1]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
            - language_ids: :math:`[B]`
        """
        outputs = {}
        sid, g, lid = self._set_cond_input(aux_input)

        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

@@ -428,6 +601,7 @@ class Vits(BaseTTS):
                x_mask,
                attn_durations,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
            )
            loss_duration = loss_duration / torch.sum(x_mask)
        else:

@@ -436,6 +610,7 @@ class Vits(BaseTTS):
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
            )
            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
        outputs["loss_duration"] = loss_duration

@@ -447,40 +622,73 @@ class Vits(BaseTTS):
        # select a random feature segment for the waveform decoder
        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
        o = self.waveform_decoder(z_slice, g=g)

        wav_seg = segment(
            waveform,
            slice_ids * self.config.audio.hop_length,
            self.args.spec_segment_size * self.config.audio.hop_length,
        )

        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
            # concatenate generated and GT waveforms
            wavs_batch = torch.cat((wav_seg, o), dim=0)

            # resample audio to speaker encoder sample_rate
            # pylint: disable=W0105
            if self.audio_transform is not None:
                wavs_batch = self.audio_transform(wavs_batch)

            pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

            # split generated and GT speaker embeddings
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
        else:
            gt_spk_emb, syn_spk_emb = None, None

        outputs.update(
            {
                "model_outputs": o,
                "alignments": attn.squeeze(1),
                "z": z,
                "z_p": z_p,
                "m_p": m_p,
                "logs_p": logs_p,
                "m_q": m_q,
                "logs_q": logs_q,
                "waveform_seg": wav_seg,
                "gt_spk_emb": gt_spk_emb,
                "syn_spk_emb": syn_spk_emb,
            }
        )
        return outputs

    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
        """
        Shapes:
            - x: :math:`[B, T_seq]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
        """
        sid, g, lid = self._set_cond_input(aux_input)
        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

        if self.args.use_sdp:
            logw = self.duration_predictor(
                x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
            )
        else:
            logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)

        w = torch.exp(logw) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)

@@ -499,12 +707,30 @@ class Vits(BaseTTS):
        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
        return outputs
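As a quick orientation, here is a hedged sketch (not part of the commit) of what `aux_input` looks like when calling `inference` on a multilingual, multi-speaker setup; `model` and the token id tensor are assumed to exist and the ids are placeholders.

```python
# Placeholder tensors; shapes follow the docstrings above (speaker_ids / language_ids: [B]).
import torch

token_ids = torch.randint(0, 100, (1, 50))  # [B, T_seq] dummy input token ids
aux_input = {
    "d_vectors": None,                      # or a d-vector tensor when use_d_vector_file is enabled
    "speaker_ids": torch.LongTensor([0]),   # [B]
    "language_ids": torch.LongTensor([1]),  # [B]
}
with torch.no_grad():
    outputs = model.inference(token_ids, aux_input=aux_input)  # `model` is an initialized Vits instance
print(outputs["model_outputs"].shape)       # generated waveform
```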

    def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
        """Forward pass for voice conversion

        TODO: create an end-point for voice conversion

        Args:
            y (Tensor): Reference spectrograms. Tensor of shape [B, T, C]
            y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
            speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
            speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
        """
        assert self.num_speakers > 0, "num_speakers have to be larger than 0."

        # speaker embedding
        if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
            g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
            g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
        elif self.args.use_speaker_embedding and self.args.use_d_vector_file:
            g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
            g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
        else:
            raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")

        z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
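A hedged usage sketch (not part of the commit): driving `voice_conversion` with speaker-embedding conditioning. The model instance, the reference spectrogram, and the speaker ids are all illustrative, and the return value is whatever the full method yields beyond the lines shown in this hunk (the decoded waveform `o_hat` among them).

```python
# Dummy inputs; a real call would use a spectrogram computed by the model's audio processor.
import torch

ref_spec = torch.randn(1, 200, 513)      # [B, T, C] reference spectrogram (illustrative channel count)
spec_lengths = torch.LongTensor([200])   # [B]
src_speaker = torch.LongTensor([0])      # reference speaker id
tgt_speaker = torch.LongTensor([3])      # target speaker id

converted = model.voice_conversion(ref_spec, spec_lengths, src_speaker, tgt_speaker)
```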
@@ -525,6 +751,30 @@ class Vits(BaseTTS):
        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if self.args.freeze_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

            if hasattr(self, "emb_l"):
                for param in self.emb_l.parameters():
                    param.requires_grad = False

        if self.args.freeze_PE:
            for param in self.posterior_encoder.parameters():
                param.requires_grad = False

        if self.args.freeze_DP:
            for param in self.duration_predictor.parameters():
                param.requires_grad = False

        if self.args.freeze_flow_decoder:
            for param in self.flow.parameters():
                param.requires_grad = False

        if self.args.freeze_waveform_decoder:
            for param in self.waveform_decoder.parameters():
                param.requires_grad = False

        if optimizer_idx == 0:
            text_input = batch["text_input"]
            text_lengths = batch["text_lengths"]

@@ -532,6 +782,7 @@ class Vits(BaseTTS):
            linear_input = batch["linear_input"]
            d_vectors = batch["d_vectors"]
            speaker_ids = batch["speaker_ids"]
            language_ids = batch["language_ids"]
            waveform = batch["waveform"]

            # generator pass

@@ -540,31 +791,26 @@ class Vits(BaseTTS):
                text_lengths,
                linear_input.transpose(1, 2),
                mel_lengths,
                waveform.transpose(1, 2),
                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
            )

            # cache tensors for the discriminator
            self.y_disc_cache = None
            self.wav_seg_disc_cache = None
            self.y_disc_cache = outputs["model_outputs"]
            self.wav_seg_disc_cache = outputs["waveform_seg"]

            # compute discriminator scores and features
            outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
                outputs["model_outputs"], outputs["waveform_seg"]
            )

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    waveform_hat=outputs["model_outputs"].float(),
                    waveform=outputs["waveform_seg"].float(),
                    z_p=outputs["z_p"].float(),
                    logs_q=outputs["logs_q"].float(),
                    m_p=outputs["m_p"].float(),

@@ -574,6 +820,9 @@ class Vits(BaseTTS):
                    feats_disc_fake=outputs["feats_disc_fake"],
                    feats_disc_real=outputs["feats_disc_real"],
                    loss_duration=outputs["loss_duration"],
                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                    gt_spk_emb=outputs["gt_spk_emb"],
                    syn_spk_emb=outputs["syn_spk_emb"],
                )

        elif optimizer_idx == 1:

@@ -651,32 +900,28 @@ class Vits(BaseTTS):
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        for idx, s_info in enumerate(test_sentences):
            try:
                aux_inputs = self.get_aux_input_from_test_sentences(s_info)
                wav, alignment, _, _ = synthesis(
                    self,
                    aux_inputs["text"],
                    self.config,
                    "cuda" in str(next(self.parameters()).device),
                    ap,
                    speaker_id=aux_inputs["speaker_id"],
                    d_vector=aux_inputs["d_vector"],
                    style_wav=aux_inputs["style_wav"],
                    language_id=aux_inputs["language_id"],
                    language_name=aux_inputs["language_name"],
                    enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                    use_griffin_lim=True,
                    do_trim_silence=False,
                ).values()
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
            except:  # pylint: disable=bare-except
                print(" !! Error creating Test Sentence -", idx)

        return test_figures, test_audios

    def get_optimizer(self) -> List:

@@ -695,8 +940,12 @@ class Vits(BaseTTS):
            self.waveform_decoder.parameters(),
        )
        # add the speaker embedding layer
        if hasattr(self, "emb_g") and self.args.use_speaker_embedding and not self.args.use_d_vector_file:
            gen_parameters = chain(gen_parameters, self.emb_g.parameters())
        # add the language embedding layer
        if hasattr(self, "emb_l") and self.args.use_language_embedding:
            gen_parameters = chain(gen_parameters, self.emb_l.parameters())

        optimizer0 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
        )

@@ -769,6 +1018,10 @@ class Vits(BaseTTS):
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load the model checkpoint and setup for training or inference"""
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        # compat band-aid for the pre-trained models to not use the encoder baked into the model
        # TODO: consider baking the speaker encoder into the model and call it from there.
        # as it is probably easier for model distribution.
        state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
@@ -0,0 +1,122 @@
import json
import os
from typing import Dict, List

import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler


class LanguageManager:
    """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
    in a way that can be queried by language.

    Args:
        language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
        TTS models. Defaults to "".
        config (Coqpit, optional): Coqpit config that contains the language information in the datasets field.
        Defaults to None.

    Examples:
        >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
        >>> language_id_mapper = manager.language_id_mapping
    """

    language_id_mapping: Dict = {}

    def __init__(
        self,
        language_ids_file_path: str = "",
        config: Coqpit = None,
    ):
        self.language_id_mapping = {}
        if language_ids_file_path:
            self.set_language_ids_from_file(language_ids_file_path)

        if config:
            self.set_language_ids_from_config(config)

    @staticmethod
    def _load_json(json_file_path: str) -> Dict:
        with fsspec.open(json_file_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict) -> None:
        with fsspec.open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

    @property
    def num_languages(self) -> int:
        return len(list(self.language_id_mapping.keys()))

    @property
    def language_names(self) -> List:
        return list(self.language_id_mapping.keys())

    @staticmethod
    def parse_language_ids_from_config(c: Coqpit) -> Dict:
        """Parse the language ids from the config.

        Args:
            c (Coqpit): Config.

        Returns:
            Dict: Language ID mapping.
        """
        languages = set({})
        for dataset in c.datasets:
            if "language" in dataset:
                languages.add(dataset["language"])
            else:
                raise ValueError(f"Dataset {dataset['name']} has no language specified.")
        return {name: i for i, name in enumerate(sorted(list(languages)))}

    def set_language_ids_from_config(self, c: Coqpit) -> None:
        """Set language IDs from the config datasets.

        Args:
            c (Coqpit): Config.
        """
        self.language_id_mapping = self.parse_language_ids_from_config(c)

    def set_language_ids_from_file(self, file_path: str) -> None:
        """Load language ids from a json file.

        Args:
            file_path (str): Path to the target json file.
        """
        self.language_id_mapping = self._load_json(file_path)

    def save_language_ids_to_file(self, file_path: str) -> None:
        """Save language IDs to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.language_id_mapping)


def _set_file_path(path):
    """Find the language_ids.json under the given path or the one above it.
    Intended to band-aid the different paths returned in restored and continued training."""
    path_restore = os.path.join(os.path.dirname(path), "language_ids.json")
    path_continue = os.path.join(path, "language_ids.json")
    fs = fsspec.get_mapper(path).fs
    if fs.exists(path_restore):
        return path_restore
    if fs.exists(path_continue):
        return path_continue
    return None


def get_language_weighted_sampler(items: list):
    language_names = np.array([item[3] for item in items])
    unique_language_names = np.unique(language_names).tolist()
    language_ids = [unique_language_names.index(l) for l in language_names]
    language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
    weight_language = 1.0 / language_count
    dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
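For context, a small runnable sketch (not part of the commit) of plugging the language-weighted sampler into a `DataLoader`. The items are toy 4-element entries where, as in the function above, index 3 holds the language name; the import path assumes this new module lives at `TTS.tts.utils.languages`, which is how `vits.py` imports `LanguageManager` earlier in this diff.

```python
# Toy data only; a real call passes `dataset.items` from TTSDataset.
import torch
from torch.utils.data import DataLoader, TensorDataset

from TTS.tts.utils.languages import get_language_weighted_sampler

items = [
    ["hello world", "wav1.wav", "spk1", "en"],
    ["hallo welt", "wav2.wav", "spk2", "de"],
    ["how are you", "wav3.wav", "spk1", "en"],
]
sampler = get_language_weighted_sampler(items)      # rarer languages get larger sampling weights
dataset = TensorDataset(torch.arange(len(items)))   # stand-in for the real TTSDataset
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
for batch in loader:
    print(batch)
```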
@@ -7,9 +7,10 @@ import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler

from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.utils.audio import AudioProcessor

@@ -161,8 +162,10 @@ class SpeakerManager:
            file_path (str): Path to the target json file.
        """
        self.d_vectors = self._load_json(file_path)

        speakers = sorted({x["name"] for x in self.d_vectors.values()})
        self.speaker_ids = {name: i for i, name in enumerate(speakers)}

        self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))

    def get_d_vector_by_clip(self, clip_idx: str) -> List:

@@ -209,6 +212,32 @@ class SpeakerManager:
        d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
        return d_vectors

    def get_random_speaker_id(self) -> Any:
        """Get a random speaker ID.

        Returns:
            Any: Speaker ID.
        """
        if self.speaker_ids:
            return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]

        return None

    def get_random_d_vector(self) -> Any:
        """Get a random d_vector.

        Returns:
            np.ndarray: d_vector.
        """
        if self.d_vectors:
            return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]

        return None

    def get_speakers(self) -> List:
        return self.speaker_ids

@@ -223,18 +252,15 @@ class SpeakerManager:
            config_path (str): Model config file path.
        """
        self.speaker_encoder_config = load_config(config_path)
        self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
        self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
        self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)

    def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
        """Compute a d_vector from a given audio file.

        Args:
            wav_file (Union[str, List[str]]): Target file path.

        Returns:
            list: Computed d_vector.

@@ -242,12 +268,16 @@ class SpeakerManager:

        def _compute(wav_file: str):
            waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
            if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
                m_input = self.speaker_encoder_ap.melspectrogram(waveform)
                m_input = torch.from_numpy(m_input)
            else:
                m_input = torch.from_numpy(waveform)

            if self.use_cuda:
                m_input = m_input.cuda()
            m_input = m_input.unsqueeze(0)
            d_vector = self.speaker_encoder.compute_embedding(m_input)
            return d_vector

        if isinstance(wav_file, list):

@@ -364,11 +394,14 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
    elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
        # new speaker manager with speaker IDs file.
        speaker_manager.set_speaker_ids_from_file(c.speakers_file)

    if speaker_manager.num_speakers > 0:
        print(
            " > Speaker manager is loaded with {} speakers: {}".format(
                speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
            )
        )

    # save file if path is defined
    if out_path:
        out_file_path = os.path.join(out_path, "speakers.json")

@@ -378,3 +411,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
    else:
        speaker_manager.save_speaker_ids_to_file(out_file_path)
    return speaker_manager


def get_speaker_weighted_sampler(items: list):
    speaker_names = np.array([item[2] for item in items])
    unique_speaker_names = np.unique(speaker_names).tolist()
    speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
    speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
    weight_speaker = 1.0 / speaker_count
    dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
    import tensorflow as tf


-def text_to_seq(text, CONFIG, custom_symbols=None):
+def text_to_seq(text, CONFIG, custom_symbols=None, language=None):
    text_cleaner = [CONFIG.text_cleaner]
    # text or phonemes to sequence vector
    if CONFIG.use_phonemes:
@ -23,7 +23,7 @@ def text_to_seq(text, CONFIG, custom_symbols=None):
            phoneme_to_sequence(
                text,
                text_cleaner,
-                CONFIG.phoneme_language,
+                language if language else CONFIG.phoneme_language,
                CONFIG.enable_eos_bos_chars,
                tp=CONFIG.characters,
                add_blank=CONFIG.add_blank,
@ -71,6 +71,7 @@ def run_model_torch(
    speaker_id: int = None,
    style_mel: torch.Tensor = None,
    d_vector: torch.Tensor = None,
+    language_id: torch.Tensor = None,
) -> Dict:
    """Run a torch model for inference. It does not support batch inference.

@ -96,6 +97,7 @@ def run_model_torch(
            "speaker_ids": speaker_id,
            "d_vectors": d_vector,
            "style_mel": style_mel,
+            "language_ids": language_id,
        },
    )
    return outputs
@ -160,19 +162,20 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
    return wav


-def speaker_id_to_torch(speaker_id, cuda=False):
-    if speaker_id is not None:
-        speaker_id = np.asarray(speaker_id)
-        speaker_id = torch.from_numpy(speaker_id)
+def id_to_torch(aux_id, cuda=False):
+    if aux_id is not None:
+        aux_id = np.asarray(aux_id)
+        aux_id = torch.from_numpy(aux_id)
    if cuda:
-        return speaker_id.cuda()
-    return speaker_id
+        return aux_id.cuda()
+    return aux_id


def embedding_to_torch(d_vector, cuda=False):
    if d_vector is not None:
        d_vector = np.asarray(d_vector)
        d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
+        d_vector = d_vector.squeeze().unsqueeze(0)
    if cuda:
        return d_vector.cuda()
    return d_vector
@ -208,6 +211,8 @@ def synthesis(
    use_griffin_lim=False,
    do_trim_silence=False,
    d_vector=None,
+    language_id=None,
+    language_name=None,
    backend="torch",
):
    """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
@ -244,6 +249,12 @@ def synthesis(
        d_vector (torch.Tensor):
            d-vector for multi-speaker models in shape :math:`[1, D]`. Defaults to None.

+        language_id (int):
+            Language ID passed to the language embedding layer in multi-lingual models. Defaults to None.
+
+        language_name (str):
+            Language name corresponding to the language code used by the phonemizer. Defaults to None.
+
        backend (str):
            tf or torch. Defaults to "torch".
    """
@ -258,15 +269,18 @@ def synthesis(
    if hasattr(model, "make_symbols"):
        custom_symbols = model.make_symbols(CONFIG)
    # preprocess the given text
-    text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
+    text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name)
    # pass tensors to backend
    if backend == "torch":
        if speaker_id is not None:
-            speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
+            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)

        if d_vector is not None:
            d_vector = embedding_to_torch(d_vector, cuda=use_cuda)

+        if language_id is not None:
+            language_id = id_to_torch(language_id, cuda=use_cuda)
+
        if not isinstance(style_mel, dict):
            style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
@ -278,7 +292,7 @@ def synthesis(
        text_inputs = tf.expand_dims(text_inputs, 0)
    # synthesize voice
    if backend == "torch":
-        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)
+        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
        model_outputs = outputs["model_outputs"]
        model_outputs = model_outputs[0].data.cpu().numpy()
        alignments = outputs["alignments"]
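A hedged sketch of the language fallback introduced in `text_to_seq`: an explicit `language` argument wins, otherwise the config's `phoneme_language` is used. The `CONFIG` object below is a stand-in, not the real Coqpit config.

```python
from types import SimpleNamespace

CONFIG = SimpleNamespace(phoneme_language="en-us")  # placeholder config

def pick_language(language=None):
    # mirrors the diff line: language if language else CONFIG.phoneme_language
    return language if language else CONFIG.phoneme_language

print(pick_language())         # "en-us" (falls back to the config)
print(pick_language("pt-br"))  # "pt-br" (explicit language wins)
```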
@ -135,3 +135,12 @@ def phoneme_cleaners(text):
    text = remove_aux_symbols(text)
    text = collapse_whitespace(text)
    return text
+
+
+def multilingual_cleaners(text):
+    """Pipeline for multilingual text"""
+    text = lowercase(text)
+    text = replace_symbols(text, lang=None)
+    text = remove_aux_symbols(text)
+    text = collapse_whitespace(text)
+    return text
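A tiny usage sketch for the new cleaner; the import path is assumed from the repository layout and the expected output is approximate.

```python
from TTS.tts.utils.text.cleaners import multilingual_cleaners  # assumed module path

print(multilingual_cleaners("Bonjour,   LE MONDE!"))  # roughly: "bonjour, le monde!"
```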
@ -16,6 +16,60 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """Some of the audio processing functions using Torch for faster batch processing.

    TODO: Merge this with audio.py
+
+    Args:
+
+        n_fft (int):
+            FFT window size for STFT.
+
+        hop_length (int):
+            number of frames between STFT columns.
+
+        win_length (int, optional):
+            STFT window length.
+
+        pad_wav (bool, optional):
+            If True pad the audio with (n_fft - hop_length) / 2. Defaults to False.
+
+        window (str, optional):
+            The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
+
+        sample_rate (int, optional):
+            target audio sampling rate. Defaults to None.
+
+        mel_fmin (int, optional):
+            minimum filter frequency for computing melspectrograms. Defaults to None.
+
+        mel_fmax (int, optional):
+            maximum filter frequency for computing melspectrograms. Defaults to None.
+
+        n_mels (int, optional):
+            number of melspectrogram dimensions. Defaults to None.
+
+        use_mel (bool, optional):
+            If True compute melspectrograms, otherwise linear spectrograms. Defaults to False.
+
+        do_amp_to_db_linear (bool, optional):
+            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
+
+        spec_gain (float, optional):
+            gain applied when converting amplitude to DB. Defaults to 1.0.
+
+        power (float, optional):
+            Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
+
+        use_htk (bool, optional):
+            Use HTK formula in mel filter instead of Slaney.
+
+        mel_norm (None, 'slaney', or number, optional):
+            If 'slaney', divide the triangular mel weights by the width of the mel band
+            (area normalization).
+
+            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
+            See `librosa.util.normalize` for a full description of supported norm values
+            (including `+-np.inf`).
+
+            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
    """

    def __init__(
@ -32,6 +86,9 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
+        power=None,
+        use_htk=False,
+        mel_norm="slaney",
    ):
        super().__init__()
        self.n_fft = n_fft
@ -45,6 +102,9 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
+        self.power = power
+        self.use_htk = use_htk
+        self.mel_norm = mel_norm
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        if use_mel:
@ -83,6 +143,10 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+
+        if self.power is not None:
+            S = S ** self.power
+
        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
@ -91,7 +155,13 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method

    def _build_mel_basis(self):
        mel_basis = librosa.filters.mel(
-            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+            self.sample_rate,
+            self.n_fft,
+            n_mels=self.n_mels,
+            fmin=self.mel_fmin,
+            fmax=self.mel_fmax,
+            htk=self.use_htk,
+            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()
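A hedged usage sketch of the extended `TorchSTFT`; the import path and all parameter values are illustrative assumptions, not taken from the commit.

```python
import torch
from TTS.utils.audio import TorchSTFT  # assumed import path

stft = TorchSTFT(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    sample_rate=22050,
    n_mels=80,
    use_mel=True,
    power=2.0,           # square the magnitudes before the mel projection
    use_htk=False,       # keep the Slaney mel formula
    mel_norm="slaney",
)
wav = torch.randn(1, 1, 22050)  # [batch, 1, samples] of fake audio
mel = stft(wav)                  # batched mel spectrogram
```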
@ -167,7 +237,7 @@ class AudioProcessor(object):
            minimum filter frequency for computing melspectrograms. Defaults to None.

        mel_fmax (int, optional):
-            maximum filter frequency for computing melspectrograms.. Defaults to None.
+            maximum filter frequency for computing melspectrograms. Defaults to None.

        spec_gain (int, optional):
            gain applied when converting amplitude to DB. Defaults to 20.
@ -196,6 +266,12 @@ class AudioProcessor(object):
        do_amp_to_db_mel (bool, optional):
            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.

+        do_rms_norm (bool, optional):
+            enable/disable RMS volume normalization when loading an audio file. Defaults to False.
+
+        db_level (int, optional):
+            dB level used for rms normalization. The range is -99 to 0. Defaults to None.
+
        stats_path (str, optional):
            Path to the computed stats file. Defaults to None.

@ -233,6 +309,8 @@ class AudioProcessor(object):
        do_sound_norm=False,
        do_amp_to_db_linear=True,
        do_amp_to_db_mel=True,
+        do_rms_norm=False,
+        db_level=None,
        stats_path=None,
        verbose=True,
        **_,
@ -264,6 +342,8 @@ class AudioProcessor(object):
        self.do_sound_norm = do_sound_norm
        self.do_amp_to_db_linear = do_amp_to_db_linear
        self.do_amp_to_db_mel = do_amp_to_db_mel
+        self.do_rms_norm = do_rms_norm
+        self.db_level = db_level
        self.stats_path = stats_path
        # setup exp_func for db to amp conversion
        if log_func == "np.log":
@ -656,21 +736,6 @@ class AudioProcessor(object):
            frame_period=1000 * self.hop_length / self.sample_rate,
        )
        f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
-        # pad = int((self.win_length / self.hop_length) / 2)
-        # f0 = [0.0] * pad + f0 + [0.0] * pad
-        # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0)
-        # f0 = np.array(f0, dtype=np.float32)
-
-        # f01, _, _ = librosa.pyin(
-        #     x,
-        #     fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
-        #     fmax=self.mel_fmax,
-        #     frame_length=self.win_length,
-        #     sr=self.sample_rate,
-        #     fill_na=0.0,
-        # )
-
-        # spec = self.melspectrogram(x)
        return f0

    ### Audio Processing ###
@ -713,10 +778,33 @@ class AudioProcessor(object):
        """
        return x / abs(x).max() * 0.95

+    @staticmethod
+    def _rms_norm(wav, db_level=-27):
+        r = 10 ** (db_level / 20)
+        a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
+        return wav * a
+
+    def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
+        """Normalize the volume based on RMS of the signal.
+
+        Args:
+            x (np.ndarray): Raw waveform.
+
+        Returns:
+            np.ndarray: RMS normalized waveform.
+        """
+        if db_level is None:
+            db_level = self.db_level
+        assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
+        wav = self._rms_norm(x, db_level)
+        return wav
+
    ### save and load ###
    def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
        """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.

+        Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
+
        Args:
            filename (str): Path to the wav file.
            sr (int, optional): Sampling rate for resampling. Defaults to None.
@ -725,8 +813,10 @@ class AudioProcessor(object):
            np.ndarray: Loaded waveform.
        """
        if self.resample:
+            # loading with resampling. It is significantly slower.
            x, sr = librosa.load(filename, sr=self.sample_rate)
        elif sr is None:
+            # SF is faster than librosa for loading files
            x, sr = sf.read(filename)
            assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
        else:
@ -738,6 +828,8 @@ class AudioProcessor(object):
            print(f" [!] File cannot be trimmed for silence - {filename}")
        if self.do_sound_norm:
            x = self.sound_norm(x)
+        if self.do_rms_norm:
+            x = self.rms_volume_norm(x, self.db_level)
        return x

    def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:
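A standalone check of the RMS normalization formula added above: after scaling, the signal's RMS sits at the requested dB level. The waveform and target level below are illustrative only.

```python
import numpy as np

def rms_norm(wav, db_level=-27):
    r = 10 ** (db_level / 20)                               # target linear RMS
    a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))   # gain that reaches it
    return wav * a

wav = np.random.uniform(-0.3, 0.3, size=16000)
out = rms_norm(wav, db_level=-27)
print(20 * np.log10(np.sqrt(np.mean(out ** 2))))            # ~ -27.0 dB
```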
@ -7,6 +7,7 @@ import tarfile
import urllib
import urllib.request
import zipfile
+from os.path import expanduser
from typing import Any, Iterable, List, Optional

from torch.utils.model_zoo import tqdm
@ -183,3 +184,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
        pass

    raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip archives.")
+
+
+def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str):
+    """Download dataset from kaggle.
+    Args:
+        dataset_path (str):
+            This is the Kaggle link to the dataset. For example, VCTK is 'mfekadu/english-multispeaker-corpus-for-voice-cloning'
+        dataset_name (str): Name of the folder the dataset will be saved in.
+        output_path (str): Path of the location you want the dataset folder to be saved to.
+    """
+    data_path = os.path.join(output_path, dataset_name)
+    try:
+        import kaggle  # pylint: disable=import-outside-toplevel
+
+        kaggle.api.authenticate()
+        print(f"""\nDownloading {dataset_name}...""")
+        kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
+    except OSError:
+        print(
+            f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
+        )
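A hedged usage sketch; it requires the `kaggle` package and an API token in `~/.kaggle/kaggle.json`, and the output path below is a placeholder.

```python
from TTS.utils.download import download_kaggle_dataset

download_kaggle_dataset(
    "mfekadu/english-multispeaker-corpus-for-voice-cloning",  # Kaggle dataset slug (from the diff)
    "VCTK",                                                   # folder name to save into
    "/tmp/datasets",                                          # output path (illustrative)
)
```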
@ -1,6 +1,7 @@
import os
+from typing import Optional

-from TTS.utils.download import download_url, extract_archive
+from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive


def download_ljspeech(path: str):
@ -18,14 +19,106 @@ def download_ljspeech(path: str):
    extract_archive(archive)


-def download_vctk(path: str):
-    """Download and extract VCTK dataset
+def download_vctk(path: str, use_kaggle: Optional[bool] = False):
+    """Download and extract VCTK dataset.

    Args:
        path (str): path to the directory where the dataset will be stored.
+
+        use_kaggle (bool, optional): Downloads vctk dataset from kaggle. Is generally faster. Defaults to False.
+    """
+    if use_kaggle:
+        download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path)
+    else:
+        os.makedirs(path, exist_ok=True)
+        url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
+        download_url(url, path)
+        basename = os.path.basename(url)
+        archive = os.path.join(path, basename)
+        print(" > Extracting archive file...")
+        extract_archive(archive)
+
+
+def download_tweb(path: str):
+    """Download and extract Tweb dataset
+
+    Args:
+        path (str): Path to the directory where the dataset will be stored.
+    """
+    download_kaggle_dataset("bryanpark/the-world-english-bible-speech-dataset", "TWEB", path)
+
+
+def download_libri_tts(path: str, subset: Optional[str] = "all"):
+    """Download and extract libri tts dataset.
+
+    Args:
+        path (str): Path to the directory where the dataset will be stored.
+
+        subset (str, optional): Name of the subset to download. If you only want to download a certain
+        portion specify it here. Defaults to 'all'.
+    """
+
+    subset_dict = {
+        "libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz",
+        "libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz",
+        "libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz",
+        "libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz",
+        "libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz",
+        "libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz",
+        "libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz",
+    }
+
+    os.makedirs(path, exist_ok=True)
+    if subset == "all":
+        for sub, val in subset_dict.items():
+            print(f" > Downloading {sub}...")
+            download_url(val, path)
+            basename = os.path.basename(val)
+            archive = os.path.join(path, basename)
+            print(" > Extracting archive file...")
+            extract_archive(archive)
+        print(" > All subsets downloaded")
+    else:
+        url = subset_dict[subset]
+        download_url(url, path)
+        basename = os.path.basename(url)
+        archive = os.path.join(path, basename)
+        print(" > Extracting archive file...")
+        extract_archive(archive)
+
+
+def download_thorsten_de(path: str):
+    """Download and extract Thorsten german male voice dataset.
+
+    Args:
+        path (str): Path to the directory where the dataset will be stored.
    """
    os.makedirs(path, exist_ok=True)
-    url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
+    url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz"
+    download_url(url, path)
+    basename = os.path.basename(url)
+    archive = os.path.join(path, basename)
+    print(" > Extracting archive file...")
+    extract_archive(archive)
+
+
+def download_mailabs(path: str, language: str = "english"):
+    """Download and extract Mailabs dataset.
+
+    Args:
+        path (str): Path to the directory where the dataset will be stored.
+
+        language (str): Language subset to download. Defaults to english.
+    """
+    language_dict = {
+        "english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz",
+        "german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz",
+        "french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz",
+        "italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz",
+        "spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz",
+    }
+    os.makedirs(path, exist_ok=True)
+    url = language_dict[language]
    download_url(url, path)
    basename = os.path.basename(url)
    archive = os.path.join(path, basename)
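A hedged usage sketch of the new downloader helpers; the module path and the local directories are assumptions, not values from the commit.

```python
from TTS.utils.downloaders import download_libri_tts, download_vctk  # assumed module path

download_vctk("/data/VCTK", use_kaggle=True)                         # Kaggle mirror, usually faster
download_libri_tts("/data/LibriTTS", subset="libri-tts-clean-100")   # fetch a single LibriTTS subset
```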
@ -26,7 +26,7 @@ class AttrDict(dict):
        self.__dict__ = self


-def copy_model_files(config: Coqpit, out_path, new_fields):
+def copy_model_files(config: Coqpit, out_path, new_fields=None):
    """Copy config.json and other model files to training folder and add
    new fields.

@ -46,36 +46,66 @@ class ModelManager(object):
        with open(file_path, "r", encoding="utf-8") as json_file:
            self.models_dict = json.load(json_file)

-    def list_langs(self):
-        print(" Name format: type/language")
-        for model_type in self.models_dict:
-            for lang in self.models_dict[model_type]:
-                print(f" >: {model_type}/{lang} ")
+    def _list_models(self, model_type, model_count=0):
+        model_list = []
+        for lang in self.models_dict[model_type]:
+            for dataset in self.models_dict[model_type][lang]:
+                for model in self.models_dict[model_type][lang][dataset]:
+                    model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
+                    output_path = os.path.join(self.output_prefix, model_full_name)
+                    if os.path.exists(output_path):
+                        print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
+                    else:
+                        print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
+                    model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
+                    model_count += 1
+        return model_list

-    def list_datasets(self):
-        print(" Name format: type/language/dataset")
-        for model_type in self.models_dict:
-            for lang in self.models_dict[model_type]:
-                for dataset in self.models_dict[model_type][lang]:
-                    print(f" >: {model_type}/{lang}/{dataset}")
+    def _list_for_model_type(self, model_type):
+        print(" Name format: language/dataset/model")
+        models_name_list = []
+        model_count = 1
+        model_type = "tts_models"
+        models_name_list.extend(self._list_models(model_type, model_count))
+        return [name.replace(model_type + "/", "") for name in models_name_list]

    def list_models(self):
        print(" Name format: type/language/dataset/model")
        models_name_list = []
        model_count = 1
+        for model_type in self.models_dict:
+            model_list = self._list_models(model_type, model_count)
+            models_name_list.extend(model_list)
+        return models_name_list
+
+    def list_tts_models(self):
+        """Print all `TTS` models and return a list of model names
+
+        Format is `language/dataset/model`
+        """
+        return self._list_for_model_type("tts_models")
+
+    def list_vocoder_models(self):
+        """Print all the `vocoder` models and return a list of model names
+
+        Format is `language/dataset/model`
+        """
+        return self._list_for_model_type("vocoder_models")
+
+    def list_langs(self):
+        """Print all the available languages"""
+        print(" Name format: type/language")
+        for model_type in self.models_dict:
+            for lang in self.models_dict[model_type]:
+                print(f" >: {model_type}/{lang} ")
+
+    def list_datasets(self):
+        """Print all the datasets"""
+        print(" Name format: type/language/dataset")
        for model_type in self.models_dict:
            for lang in self.models_dict[model_type]:
                for dataset in self.models_dict[model_type][lang]:
-                    for model in self.models_dict[model_type][lang][dataset]:
-                        model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
-                        output_path = os.path.join(self.output_prefix, model_full_name)
-                        if os.path.exists(output_path):
-                            print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
-                        else:
-                            print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
-                        models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
-                        model_count += 1
-        return models_name_list
+                    print(f" >: {model_type}/{lang}/{dataset}")

    def download_model(self, model_name):
        """Download model files given the full model name.
@ -121,6 +151,8 @@ class ModelManager(object):
        output_stats_path = os.path.join(output_path, "scale_stats.npy")
        output_d_vector_file_path = os.path.join(output_path, "speakers.json")
        output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
+        speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
+        speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")

        # update the scale_path.npy file path in the model config.json
        self._update_path("audio.stats_path", output_stats_path, config_path)
@ -133,6 +165,12 @@ class ModelManager(object):
        self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
        self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)

+        # update the speaker_encoder file path in the model config.json to the current path
+        self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
+        self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
+        self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
+        self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
+
    @staticmethod
    def _update_path(field_name, new_path, config_path):
        """Update the path in the model config.json for the current environment after download"""
@ -159,8 +197,12 @@ class ModelManager(object):
        # download the file
        r = requests.get(file_url)
        # extract the file
-        with zipfile.ZipFile(io.BytesIO(r.content)) as z:
-            z.extractall(output_folder)
+        try:
+            with zipfile.ZipFile(io.BytesIO(r.content)) as z:
+                z.extractall(output_folder)
+        except zipfile.BadZipFile:
+            print(f" > Error: Bad zip file - {file_url}")
+            raise zipfile.BadZipFile  # pylint: disable=raise-missing-from
        # move the files to the outer path
        for file_path in z.namelist()[1:]:
            src_path = os.path.join(output_folder, file_path)
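A hedged sketch of the new listing helpers; the path to the models database is an assumption and the returned name formats are taken from the docstrings above.

```python
from TTS.utils.manage import ModelManager

manager = ModelManager("TTS/.models.json")   # path to the models database (assumed)
all_names = manager.list_models()            # "type/language/dataset/model" strings
tts_names = manager.list_tts_models()        # "language/dataset/model" strings for tts_models only
```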
@ -1,12 +1,13 @@
import time
-from typing import List
+from typing import List, Union

import numpy as np
import pysbd
import torch

-from TTS.config import load_config
+from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config
from TTS.tts.models import setup_model as setup_tts_model
+from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager

# pylint: disable=unused-wildcard-import
@ -23,6 +24,7 @@ class Synthesizer(object):
        tts_checkpoint: str,
        tts_config_path: str,
        tts_speakers_file: str = "",
+        tts_languages_file: str = "",
        vocoder_checkpoint: str = "",
        vocoder_config: str = "",
        encoder_checkpoint: str = "",
@ -52,6 +54,7 @@ class Synthesizer(object):
        self.tts_checkpoint = tts_checkpoint
        self.tts_config_path = tts_config_path
        self.tts_speakers_file = tts_speakers_file
+        self.tts_languages_file = tts_languages_file
        self.vocoder_checkpoint = vocoder_checkpoint
        self.vocoder_config = vocoder_config
        self.encoder_checkpoint = encoder_checkpoint
@ -63,6 +66,9 @@ class Synthesizer(object):
        self.speaker_manager = None
        self.num_speakers = 0
        self.tts_speakers = {}
+        self.language_manager = None
+        self.num_languages = 0
+        self.tts_languages = {}
        self.d_vector_dim = 0
        self.seg = self._get_segmenter("en")
        self.use_cuda = use_cuda
@ -110,29 +116,94 @@ class Synthesizer(object):
        self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)

        speaker_manager = self._init_speaker_manager()
+        language_manager = self._init_language_manager()
+        if not self.encoder_checkpoint:
+            self._set_speaker_encoder_paths_from_tts_config()
+        speaker_manager = self._init_speaker_encoder(speaker_manager)

-        self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
+        if language_manager is not None:
+            self.tts_model = setup_tts_model(
+                config=self.tts_config,
+                speaker_manager=speaker_manager,
+                language_manager=language_manager,
+            )
+        else:
+            self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
        self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
        if use_cuda:
            self.tts_model.cuda()

+    def _set_speaker_encoder_paths_from_tts_config(self):
+        """Set the encoder paths from the tts model config for models with speaker encoders."""
+        if hasattr(self.tts_config, "model_args") and hasattr(
+            self.tts_config.model_args, "speaker_encoder_config_path"
+        ):
+            self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
+            self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
+
+    def _is_use_speaker_embedding(self):
+        """Check if the speaker embedding is used in the model"""
+        # we handle here the case that some models use model_args and some don't
+        use_speaker_embedding = False
+        if hasattr(self.tts_config, "model_args"):
+            use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False)
+        use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False)
+        return use_speaker_embedding
+
+    def _is_use_d_vector_file(self):
+        """Check if the d-vector file is used in the model"""
+        # we handle here the case that some models use model_args and some don't
+        use_d_vector_file = False
+        if hasattr(self.tts_config, "model_args"):
+            config = self.tts_config.model_args
+            use_d_vector_file = config.get("use_d_vector_file", False)
+        config = self.tts_config
+        use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False)
+        return use_d_vector_file
+
    def _init_speaker_manager(self):
        """Initialize the SpeakerManager"""
        # setup if multi-speaker settings are in the global model config
        speaker_manager = None
-        if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
+        speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None)
+        if self._is_use_speaker_embedding():
            if self.tts_speakers_file:
                speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
-            if self.tts_config.get("speakers_file", None):
-                speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
+            elif speakers_file:
+                speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file)

-        if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True:
+        if self._is_use_d_vector_file():
+            d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None)
            if self.tts_speakers_file:
                speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
-            if self.tts_config.get("d_vector_file", None):
-                speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
+            elif d_vector_file:
+                speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file)
        return speaker_manager

+    def _init_speaker_encoder(self, speaker_manager):
+        """Initialize the SpeakerEncoder"""
+        if self.encoder_checkpoint:
+            if speaker_manager is None:
+                speaker_manager = SpeakerManager(
+                    encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config
+                )
+            else:
+                speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
+        return speaker_manager
+
+    def _init_language_manager(self):
+        """Initialize the LanguageManager"""
+        # setup if multi-lingual settings are in the global model config
+        language_manager = None
+        if check_config_and_model_args(self.tts_config, "use_language_embedding", True):
+            if self.tts_languages_file:
+                language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
+            elif self.tts_config.get("language_ids_file", None):
+                language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file)
+            else:
+                language_manager = LanguageManager(config=self.tts_config)
+        return language_manager
+
    def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
        """Load the vocoder model.

@ -174,13 +245,21 @@ class Synthesizer(object):
        wav = np.array(wav)
        self.ap.save_wav(wav, path, self.output_sample_rate)

-    def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]:
+    def tts(
+        self,
+        text: str,
+        speaker_name: str = "",
+        language_name: str = "",
+        speaker_wav: Union[str, List[str]] = None,
+        style_wav=None,
+    ) -> List[int]:
        """🐸 TTS magic. Run all the models and generate speech.

        Args:
            text (str): input text.
-            speaker_idx (str, optional): speaker id for multi-speaker models. Defaults to "".
-            speaker_wav ():
+            speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
+            language_name (str, optional): language id for multi-language models. Defaults to "".
+            speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
            style_wav ([type], optional): style waveform for GST. Defaults to None.

        Returns:
@ -196,29 +275,49 @@ class Synthesizer(object):
        speaker_embedding = None
        speaker_id = None
        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
-            if speaker_idx and isinstance(speaker_idx, str):
+            if speaker_name and isinstance(speaker_name, str):
                if self.tts_config.use_d_vector_file:
                    # get the speaker embedding from the saved d_vectors.
-                    speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
+                    speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_name)[0]
                    speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
                else:
                    # get speaker idx from the speaker name
-                    speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
+                    speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name]

-            elif not speaker_idx and not speaker_wav:
+            elif not speaker_name and not speaker_wav:
                raise ValueError(
                    " [!] Looks like you are using a multi-speaker model. "
-                    "You need to define either a `speaker_idx` or a `style_wav` to use a multi-speaker model."
+                    "You need to define either a `speaker_name` or a `style_wav` to use a multi-speaker model."
                )
            else:
                speaker_embedding = None
        else:
-            if speaker_idx:
+            if speaker_name:
                raise ValueError(
-                    f" [!] Missing speakers.json file path for selecting speaker {speaker_idx}."
+                    f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
                    "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
                )

+        # handle multi-lingual
+        language_id = None
+        if self.tts_languages_file or (
+            hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
+        ):
+            if language_name and isinstance(language_name, str):
+                language_id = self.tts_model.language_manager.language_id_mapping[language_name]
+
+            elif not language_name:
+                raise ValueError(
+                    " [!] Looks like you are using a multi-lingual model. "
+                    "You need to define either a `language_name` or a `style_wav` to use a multi-lingual model."
+                )
+
+            else:
+                raise ValueError(
+                    f" [!] Missing language_ids.json file path for selecting language {language_name}."
+                    "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
+                )
+
        # compute a new d_vector from the given clip.
        if speaker_wav is not None:
            speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
@ -234,6 +333,8 @@ class Synthesizer(object):
                use_cuda=self.use_cuda,
                ap=self.ap,
                speaker_id=speaker_id,
+                language_id=language_id,
+                language_name=language_name,
                style_wav=style_wav,
                enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
                use_griffin_lim=use_gl,
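A hedged sketch of a multi-speaker, multi-lingual call through the extended `Synthesizer`; the file names and the speaker/language labels are placeholders, not values from the commit.

```python
from TTS.utils.synthesizer import Synthesizer

synth = Synthesizer(
    tts_checkpoint="model.pth.tar",          # placeholder checkpoint
    tts_config_path="config.json",           # placeholder config
    tts_speakers_file="speakers.json",       # placeholder speaker ids file
    tts_languages_file="language_ids.json",  # placeholder language ids file
    use_cuda=False,
)
wav = synth.tts("Hello world.", speaker_name="p225", language_name="en")
synth.save_wav(wav, "out.wav")
```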
@ -0,0 +1,144 @@
# This code is adapted from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
import collections
import contextlib
import wave

import webrtcvad


def read_wave(path):
    """Reads a .wav file.

    Takes the path, and returns (PCM audio data, sample rate).
    """
    with contextlib.closing(wave.open(path, "rb")) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate


def write_wave(path, audio, sample_rate):
    """Writes a .wav file.

    Takes path, PCM audio data, and sample rate.
    """
    with contextlib.closing(wave.open(path, "wb")) as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(audio)


class Frame(object):
    """Represents a "frame" of audio data."""

    def __init__(self, _bytes, timestamp, duration):
        self.bytes = _bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    """Generates audio frames from PCM audio data.

    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate.

    Yields Frames of the requested duration.
    """
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset : offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
    """Filters out non-voiced audio frames.

    Given a webrtcvad.Vad and a source of audio frames, yields only
    the voiced audio.

    Uses a padded, sliding window algorithm over the audio frames.
    When more than 90% of the frames in the window are voiced (as
    reported by the VAD), the collector triggers and begins yielding
    audio frames. Then the collector waits until 90% of the frames in
    the window are unvoiced to detrigger.

    The window is padded at the front and back to provide a small
    amount of silence or the beginnings/endings of speech around the
    voiced frames.

    Arguments:

    sample_rate - The audio sample rate, in Hz.
    frame_duration_ms - The frame duration in milliseconds.
    padding_duration_ms - The amount to pad the window, in milliseconds.
    vad - An instance of webrtcvad.Vad.
    frames - a source of audio frames (sequence or generator).

    Returns: A generator that yields PCM audio data.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    # We use a deque for our sliding window/ring buffer.
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
    # NOTTRIGGERED state.
    triggered = False

    voiced_frames = []
    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)

        # sys.stdout.write('1' if is_speech else '0')
        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            # If we're NOTTRIGGERED and more than 90% of the frames in
            # the ring buffer are voiced frames, then enter the
            # TRIGGERED state.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True
                # sys.stdout.write('+(%s)' % (ring_buffer[0][0].timestamp,))
                # We want to yield all the audio we see from now until
                # we are NOTTRIGGERED, but we have to start with the
                # audio that's already in the ring buffer.
                for f, _ in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            # We're in the TRIGGERED state, so collect the audio data
            # and add it to the ring buffer.
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            # If more than 90% of the frames in the ring buffer are
            # unvoiced, then enter NOTTRIGGERED and yield whatever
            # audio we've collected.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
                triggered = False
                yield b"".join([f.bytes for f in voiced_frames])
                ring_buffer.clear()
                voiced_frames = []
    # If we have any leftover voiced audio when we run out of input,
    # yield it.
    if voiced_frames:
        yield b"".join([f.bytes for f in voiced_frames])


def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):

    vad = webrtcvad.Vad(int(aggressiveness))
    frames = list(frame_generator(30, audio, sample_rate))
    segments = vad_collector(sample_rate, 30, padding_duration_ms, vad, frames)

    return segments
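A hedged usage sketch of the VAD helpers above; the file names are placeholders and the input wav must be mono, 16-bit PCM at 8/16/32/48 kHz as the asserts in `read_wave` require.

```python
audio, sample_rate = read_wave("speech.wav")
segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300)
for i, segment in enumerate(segments):
    write_wave(f"chunk-{i:02d}.wav", segment, sample_rate)
```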
@ -113,8 +113,10 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
            Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
        lr_scheduler_disc (torch.optim.Scheduler):
            Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`.
-        lr_scheduler_dict_params (dict):
+        lr_scheduler_disc_params (dict):
            Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
+        scheduler_after_epoch (bool):
+            Whether to update the learning rate schedulers after each epoch. Defaults to True.
        use_pqmf (bool):
            enable / disable PQMF for subband approximation at training. Defaults to False.
        steps_to_start_discriminator (int):
@ -173,6 +175,7 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
+    scheduler_after_epoch: bool = True

    use_pqmf: bool = False  # enable/disable using pqmf for multi-band training. (Multi-band MelGAN)
    steps_to_start_discriminator = 0  # start training the discriminator after this number of steps.
@ -202,7 +202,9 @@ class GAN(BaseVocoder):
    ) -> Tuple[Dict, np.ndarray]:
        """Call `_log()` for training."""
        ap = assets["audio_processor"]
-        self._log("train", ap, batch, outputs)
+        figures, audios = self._log("eval", ap, batch, outputs)
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, ap.sample_rate)

    @torch.no_grad()
    def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
@ -214,7 +216,9 @@ class GAN(BaseVocoder):
    ) -> Tuple[Dict, np.ndarray]:
        """Call `_log()` for evaluation."""
        ap = assets["audio_processor"]
-        self._log("eval", ap, batch, outputs)
+        figures, audios = self._log("eval", ap, batch, outputs)
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, ap.sample_rate)

    def load_checkpoint(
        self,
@@ -3,10 +3,15 @@
 VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech
 ) is an End-to-End (encoder -> vocoder together) TTS model that takes advantage of SOTA DL techniques like GANs, VAE,
 Normalizing Flows. It does not require external alignment annotations and learns the text-to-audio alignment
-using MAS as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder.
+using MAS, as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder.
 It is a feed-forward model with x67.12 real-time factor on a GPU.
 
+🐸 YourTTS is a multi-speaker and multi-lingual TTS model that can perform voice conversion and zero-shot speaker adaptation.
+It can also learn a new language or voice from a ~1 minute long audio clip, which opens the door to training
+TTS models for low-resource languages. 🐸 YourTTS uses VITS as the backbone architecture, coupled with a speaker encoder model.
+
 ## Important resources & papers
+- 🐸 YourTTS: https://arxiv.org/abs/2112.02418
 - VITS: https://arxiv.org/pdf/2106.06103.pdf
 - Neural Spline Flows: https://arxiv.org/abs/1906.04032
 - Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf
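For orientation, the recipes added later in this change wire these pieces together end to end; below is a minimal single-speaker sketch assembled from the same APIs. The dataset path, run name, and hyperparameter values are illustrative placeholders, not part of the original docs or this change.

```python
# Minimal VITS training sketch; paths and hyperparameters are placeholders.
import os

from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path="/data/LJSpeech-1.1/")

config = VitsConfig(
    audio=BaseAudioConfig(sample_rate=22050),
    run_name="vits_ljspeech",
    batch_size=32,
    epochs=1000,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    output_path=output_path,
    datasets=[dataset_config],
)

# audio processor, samples, model, and trainer follow the same pattern as the recipes below
ap = AudioProcessor(**config.audio.to_dict())
train_samples, eval_samples = load_tts_samples([dataset_config], eval_split=True)
model = Vits(config)
trainer = Trainer(
    TrainingArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    training_assets={"audio_processor": ap},
)
trainer.fit()
```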
@@ -180,7 +180,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
 
     plt.figure()
     plt.rcParams["figure.figsize"] = (50, 20)
-    barplot = sns.barplot(x, y)
+    barplot = sns.barplot(x=x, y=y)
     if save_path:
         fig = barplot.get_figure()
         fig.savefig(os.path.join(save_path, "phoneme_dist"))
@@ -0,0 +1,130 @@
import os
from glob import glob

from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))

mailabs_path = "/home/julian/workspace/mailabs/**"
dataset_paths = glob(mailabs_path)
dataset_config = [
    BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
    for path in dataset_paths
]

audio_config = BaseAudioConfig(
    sample_rate=16000,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    preemphasis=0.0,
    ref_level_db=20,
    log_func="np.log",
    do_trim_silence=False,
    trim_db=23.0,
    mel_fmin=0,
    mel_fmax=None,
    spec_gain=1.0,
    signal_norm=True,
    do_amp_to_db_linear=False,
    resample=False,
)

vitsArgs = VitsArgs(
    use_language_embedding=True,
    embedded_language_dim=4,
    use_speaker_embedding=True,
    use_sdp=False,
)

config = VitsConfig(
    model_args=vitsArgs,
    audio=audio_config,
    run_name="vits_vctk",
    use_speaker_embedding=True,
    batch_size=32,
    eval_batch_size=16,
    batch_group_size=0,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    text_cleaner="multilingual_cleaners",
    use_phonemes=False,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    compute_input_seq_cache=True,
    print_step=25,
    use_language_weighted_sampler=True,
    print_eval=False,
    mixed_precision=False,
    sort_by_audio_len=True,
    min_seq_len=32 * 256 * 4,
    max_seq_len=160000,
    output_path=output_path,
    datasets=dataset_config,
    characters={
        "pad": "_",
        "eos": "&",
        "bos": "*",
        "characters": "!¡'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
        "punctuations": "!¡'(),-.:;¿? ",
        "phonemes": None,
        "unique": True,
    },
    test_sentences=[
        [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "mary_ann",
            None,
            "en_US",
        ],
        [
            "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
            "ezwa",
            None,
            "fr_FR",
        ],
        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
    ],
)

# init audio processor
ap = AudioProcessor(**config.audio.to_dict())

# load training samples
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)

# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
config.model_args.num_speakers = speaker_manager.num_speakers

language_manager = LanguageManager(config=config)
config.model_args.num_languages = language_manager.num_languages

# init model
model = Vits(config, speaker_manager, language_manager)

# init the trainer and 🚀
trainer = Trainer(
    TrainingArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    training_assets={"audio_processor": ap},
)
trainer.fit()
@@ -5,12 +5,14 @@ from TTS.trainer import Trainer, TrainingArgs
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits
+from TTS.tts.models.vits import Vits, VitsArgs
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
 
 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_config = BaseDatasetConfig(
+    name="vctk", meta_file_train="", language="en-us", path=os.path.join(output_path, "../VCTK/")
+)
 
 
 audio_config = BaseAudioConfig(
@@ -31,10 +33,14 @@ audio_config = BaseAudioConfig(
     resample=True,
 )
 
+vitsArgs = VitsArgs(
+    use_speaker_embedding=True,
+)
+
 config = VitsConfig(
+    model_args=vitsArgs,
     audio=audio_config,
     run_name="vits_vctk",
-    use_speaker_embedding=True,
     batch_size=32,
     eval_batch_size=16,
     batch_group_size=5,
@@ -45,7 +51,6 @@ config = VitsConfig(
     epochs=1000,
     text_cleaner="english_cleaners",
     use_phonemes=True,
-    phoneme_language="en-us",
     phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
     compute_input_seq_cache=True,
     print_step=25,
@@ -26,3 +26,5 @@ unidic-lite==1.0.8
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
 fsspec>=2021.04.0
 pyworld
+webrtcvad
+torchaudio
@@ -38,3 +38,14 @@ def run_cli(command):
 def get_test_data_config():
     return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")
+
+
+def assertHasAttr(test_obj, obj, intendedAttr):
+    # from https://stackoverflow.com/questions/48078636/pythons-unittest-lacks-an-asserthasattr-method-what-should-i-use-instead
+    testBool = hasattr(obj, intendedAttr)
+    test_obj.assertTrue(testBool, msg=f"obj lacking an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
+
+
+def assertHasNotAttr(test_obj, obj, intendedAttr):
+    testBool = hasattr(obj, intendedAttr)
+    test_obj.assertFalse(testBool, msg=f"obj should not have an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
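These two helpers act like regular `unittest` assertions bound to a test instance; the `test_vits.py` additions later in this change call them as `assertHasAttr(self, model, "emb_g")`. A tiny usage sketch follows; the `Dummy` class is purely illustrative and not part of the change.

```python
# Illustrative-only sketch of the new attribute assertions; Dummy is not part of the change.
import unittest

from tests import assertHasAttr, assertHasNotAttr


class Dummy:
    def __init__(self):
        self.emb_g = "speaker embedding table"


class TestAttrHelpers(unittest.TestCase):
    def test_attrs(self):
        obj = Dummy()
        assertHasAttr(self, obj, "emb_g")      # passes: attribute exists
        assertHasNotAttr(self, obj, "emb_l")   # passes: attribute is absent


if __name__ == "__main__":
    unittest.main()
```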
@@ -0,0 +1,80 @@
import os
import unittest

import torch

from tests import get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig

torch.manual_seed(1)

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")

dataset_config_en = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)

dataset_config_pt = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)


# pylint: disable=protected-access
class TestFindUniquePhonemes(unittest.TestCase):
    @staticmethod
    def test_espeak_phonemes():
        # prepare the config
        config = VitsConfig(
            batch_size=2,
            eval_batch_size=2,
            num_loader_workers=0,
            num_eval_loader_workers=0,
            text_cleaner="english_cleaners",
            use_phonemes=True,
            use_espeak_phonemes=True,
            phoneme_language="en-us",
            phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
            run_eval=True,
            test_delay_epochs=-1,
            epochs=1,
            print_step=1,
            print_eval=True,
            datasets=[dataset_config_en, dataset_config_pt],
        )
        config.save_json(config_path)

        # run test
        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')

    @staticmethod
    def test_no_espeak_phonemes():
        # prepare the config
        config = VitsConfig(
            batch_size=2,
            eval_batch_size=2,
            num_loader_workers=0,
            num_eval_loader_workers=0,
            text_cleaner="english_cleaners",
            use_phonemes=True,
            use_espeak_phonemes=False,
            phoneme_language="en-us",
            phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
            run_eval=True,
            test_delay_epochs=-1,
            epochs=1,
            print_step=1,
            print_eval=True,
            datasets=[dataset_config_en, dataset_config_pt],
        )
        config.save_json(config_path)

        # run test
        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
@@ -0,0 +1,29 @@
import os
import unittest

import torch

from tests import get_tests_input_path, get_tests_output_path, run_cli

torch.manual_seed(1)


# pylint: disable=protected-access
class TestRemoveSilenceVAD(unittest.TestCase):
    @staticmethod
    def test():
        # set paths
        wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs")
        output_path = os.path.join(get_tests_output_path(), "output_wavs_removed_silence/")
        output_resample_path = os.path.join(get_tests_output_path(), "output_ljspeech_16khz/")

        # resample audios
        run_cli(
            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/resample.py --input_dir "{wav_path}" --output_dir "{output_resample_path}" --output_sr 16000'
        )

        # run test
        run_cli(
            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/remove_silence_using_vad.py --input_dir "{output_resample_path}" --output_dir "{output_path}"'
        )
        run_cli(f'rm -rf "{output_resample_path}"')
        run_cli(f'rm -rf "{output_path}"')
@@ -13,7 +13,7 @@ file_path = get_tests_input_path()
 class LSTMSpeakerEncoderTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
-        dummy_input = T.rand(4, 20, 80)  # B x T x D
+        dummy_input = T.rand(4, 80, 20)  # B x D x T
         dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
         model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
         # computing d vectors
@@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
         assert output.type() == "torch.FloatTensor"
         assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
         # compute d for a given batch
-        dummy_input = T.rand(1, 240, 80)  # B x T x D
+        dummy_input = T.rand(1, 80, 240)  # B x D x T
         output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
         assert output.shape[0] == 1
         assert output.shape[1] == 256
@@ -44,7 +44,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
 class ResNetSpeakerEncoderTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
-        dummy_input = T.rand(4, 20, 80)  # B x T x D
+        dummy_input = T.rand(4, 80, 20)  # B x D x T
         dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
         model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
         # computing d vectors
@@ -61,7 +61,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
         assert output.type() == "torch.FloatTensor"
         assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
         # compute d for a given batch
-        dummy_input = T.rand(1, 240, 80)  # B x T x D
+        dummy_input = T.rand(1, 80, 240)  # B x D x T
         output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
         assert output.shape[0] == 1
         assert output.shape[1] == 256
@@ -6,7 +6,7 @@ import torch
 
 from tests import get_tests_input_path
 from TTS.config import load_config
-from TTS.speaker_encoder.utils.generic_utils import setup_model
+from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 from TTS.speaker_encoder.utils.io import save_checkpoint
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
@@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase):
         config.audio.resample = True
 
         # create a dummy speaker encoder
-        model = setup_model(config)
+        model = setup_speaker_encoder_model(config)
         save_checkpoint(model, None, None, get_tests_input_path(), 0)
 
         # load audio processor and speaker encoder
@@ -38,7 +38,7 @@ class SpeakerManagerTest(unittest.TestCase):
         # load a sample audio and compute embedding
         waveform = ap.load_wav(sample_wav_path)
         mel = ap.melspectrogram(waveform)
-        d_vector = manager.compute_d_vector(mel.T)
+        d_vector = manager.compute_d_vector(mel)
         assert d_vector.shape[1] == 256
 
         # compute d_vector directly from an input file
@@ -38,6 +38,11 @@ class TestTTSDataset(unittest.TestCase):
 
     def _create_dataloader(self, batch_size, r, bgs):
         items = ljspeech(c.data_path, "metadata.csv")
+
+        # add a default language because the TTSDataset now expects a language
+        language = ""
+        items = [[*item, language] for item in items]
+
         dataset = TTSDataset(
             r,
             c.text_cleaner,
@@ -0,0 +1,58 @@
import functools

import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_weighted_sampler

# Fixing random state to avoid random fails
torch.manual_seed(0)

dataset_config_en = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)

dataset_config_pt = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)

# Adding the EN samples twice to create an unbalanced dataset
train_samples, eval_samples = load_tts_samples(
    [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)


def is_balanced(lang_1, lang_2):
    return 0.85 < lang_1 / lang_2 < 1.2


random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
    if train_samples[index][3] == "en":
        en += 1
    else:
        pt += 1

assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"

weighted_sampler = get_language_weighted_sampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
    if train_samples[index][3] == "en":
        en += 1
    else:
        pt += 1

assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
@@ -0,0 +1,5 @@
{
    "en": 0,
    "fr-fr": 1,
    "pt-br": 2
}
@@ -0,0 +1,240 @@
import os
import unittest

import torch

from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager

LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")


torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# pylint: disable=no-self-use
class TestVits(unittest.TestCase):
    def test_init_multispeaker(self):
        num_speakers = 10
        args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
        model = Vits(args)
        assertHasAttr(self, model, "emb_g")

        args = VitsArgs(num_speakers=0, use_speaker_embedding=True)
        model = Vits(args)
        assertHasNotAttr(self, model, "emb_g")

        args = VitsArgs(num_speakers=10, use_speaker_embedding=False)
        model = Vits(args)
        assertHasNotAttr(self, model, "emb_g")

        args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
        model = Vits(args)
        self.assertEqual(model.embedded_speaker_dim, 101)

    def test_init_multilingual(self):
        args = VitsArgs(language_ids_file=None, use_language_embedding=False)
        model = Vits(args)
        self.assertEqual(model.language_manager, None)
        self.assertEqual(model.embedded_language_dim, 0)
        self.assertEqual(model.emb_l, None)

        args = VitsArgs(language_ids_file=LANG_FILE)
        model = Vits(args)
        self.assertNotEqual(model.language_manager, None)
        self.assertEqual(model.embedded_language_dim, 0)
        self.assertEqual(model.emb_l, None)

        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
        model = Vits(args)
        self.assertNotEqual(model.language_manager, None)
        self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
        self.assertNotEqual(model.emb_l, None)

        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
        model = Vits(args)
        self.assertNotEqual(model.language_manager, None)
        self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
        self.assertNotEqual(model.emb_l, None)

    def test_get_aux_input(self):
        aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
        args = VitsArgs()
        model = Vits(args)
        aux_out = model.get_aux_input(aux_input)

        speaker_id = torch.randint(10, (1,))
        language_id = torch.randint(10, (1,))
        d_vector = torch.rand(1, 128)
        aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
        aux_out = model.get_aux_input(aux_input)
        self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
        self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
        self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape)

    def test_voice_conversion(self):
        num_speakers = 10
        spec_len = 101
        spec_effective_len = 50

        args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
        model = Vits(args)

        ref_inp = torch.randn(1, spec_len, 513)
        ref_inp_len = torch.randint(1, spec_effective_len, (1,))
        ref_spk_id = torch.randint(1, num_speakers, (1,))
        tgt_spk_id = torch.randint(1, num_speakers, (1,))
        o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)

        self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))
        self.assertEqual(y_mask.shape, (1, 1, spec_len))
        self.assertEqual(y_mask.sum(), ref_inp_len[0])
        self.assertEqual(z.shape, (1, args.hidden_channels, spec_len))
        self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len))
        self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len))

    def _init_inputs(self, config):
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device)
        spec_lengths = torch.randint(20, 30, (8,)).long().to(device)
        spec_lengths[-1] = spec.size(2)
        waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device)
        return input_dummy, input_lengths, spec, spec_lengths, waveform

    def _check_forward_outputs(self, config, output_dict, encoder_config=None):
        self.assertEqual(
            output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
        )
        self.assertEqual(output_dict["alignments"].shape, (8, 128, 30))
        self.assertEqual(output_dict["alignments"].max(), 1)
        self.assertEqual(output_dict["alignments"].min(), 0)
        self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30))
        self.assertEqual(
            output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
        )
        if encoder_config:
            self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
            self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
        else:
            self.assertEqual(output_dict["gt_spk_emb"], None)
            self.assertEqual(output_dict["syn_spk_emb"], None)

    def test_forward(self):
        num_speakers = 0
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
        config.model_args.spec_segment_size = 10
        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
        model = Vits(config).to(device)
        output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
        self._check_forward_outputs(config, output_dict)

    def test_multispeaker_forward(self):
        num_speakers = 10

        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
        config.model_args.spec_segment_size = 10

        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
        speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)

        model = Vits(config).to(device)
        output_dict = model.forward(
            input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids}
        )
        self._check_forward_outputs(config, output_dict)

    def test_multilingual_forward(self):
        num_speakers = 10
        num_langs = 3

        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)

        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
        speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)

        model = Vits(config).to(device)
        output_dict = model.forward(
            input_dummy,
            input_lengths,
            spec,
            spec_lengths,
            waveform,
            aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
        )
        self._check_forward_outputs(config, output_dict)

    def test_secl_forward(self):
        num_speakers = 10
        num_langs = 3

        speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
        speaker_encoder_config.model_params["use_torch_spec"] = True
        speaker_encoder = setup_speaker_encoder_model(speaker_encoder_config).to(device)
        speaker_manager = SpeakerManager()
        speaker_manager.speaker_encoder = speaker_encoder

        args = VitsArgs(
            language_ids_file=LANG_FILE,
            use_language_embedding=True,
            spec_segment_size=10,
            use_speaker_encoder_as_loss=True,
        )
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
        config.audio.sample_rate = 16000

        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
        speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)

        model = Vits(config, speaker_manager=speaker_manager).to(device)
        output_dict = model.forward(
            input_dummy,
            input_lengths,
            spec,
            spec_lengths,
            waveform,
            aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
        )
        self._check_forward_outputs(config, output_dict, speaker_encoder_config)

    def test_inference(self):
        num_speakers = 0
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
        model = Vits(config).to(device)
        _ = model.inference(input_dummy)

    def test_multispeaker_inference(self):
        num_speakers = 10
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
        speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
        model = Vits(config).to(device)
        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids})

    def test_multilingual_inference(self):
        num_speakers = 10
        num_langs = 3
        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
        speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (1,)).long().to(device)
        model = Vits(config).to(device)
        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})
@@ -0,0 +1,62 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech-0"],
    ],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# active multispeaker d-vec mode
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_dim = 256


config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
@@ -0,0 +1,91 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


dataset_config_en = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)

dataset_config_pt = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)

config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=False,
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech-0", None, "en"],
        ["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
    ],
    datasets=[dataset_config_en, dataset_config_pt],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# active multilingual mode
config.model_args.use_language_embedding = True
config.use_language_embedding = True

# deactivate multispeaker mode
config.model_args.use_speaker_embedding = False
config.use_speaker_embedding = False

# active multispeaker d-vec mode
config.model_args.use_d_vector_file = True
config.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_dim = 256
config.d_vector_dim = 256

# duration predictor
config.model_args.use_sdp = True
config.use_sdp = True

# deactivate language sampler
config.use_language_weighted_sampler = False

config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
@@ -0,0 +1,88 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


dataset_config_en = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)

dataset_config_pt = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)

config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech", None, "en"],
        ["Be a voice, not an echo.", "ljspeech", None, "pt-br"],
    ],
    datasets=[dataset_config_en, dataset_config_pt],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# active multilingual mode
config.model_args.use_language_embedding = True
config.use_language_embedding = True
# active multispeaker mode
config.model_args.use_speaker_embedding = True
config.use_speaker_embedding = True

# deactivate multispeaker d-vec mode
config.model_args.use_d_vector_file = False
config.use_d_vector_file = False

# duration predictor
config.model_args.use_sdp = False
config.use_sdp = False

# active language sampler
config.use_language_weighted_sampler = True

config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
@@ -0,0 +1,63 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech"],
    ],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# active multispeaker d-vec mode
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = None
config.model_args.d_vector_dim = 256


config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
@@ -25,7 +25,7 @@ config = VitsConfig(
     print_step=1,
     print_eval=True,
     test_sentences=[
-        "Be a voice, not an echo.",
+        ["Be a voice, not an echo."],
     ],
 )
 config.audio.do_trim_silence = True
@@ -4,6 +4,7 @@ import os
 import shutil
 
 from tests import get_tests_output_path, run_cli
+from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.generic_utils import get_user_data_dir
 from TTS.utils.manage import ModelManager
@@ -17,21 +18,30 @@ def test_run_all_models():
     manager = ModelManager(output_prefix=get_tests_output_path())
     model_names = manager.list_models()
     for model_name in model_names:
+        print(f"\n > Run - {model_name}")
         model_path, _, _ = manager.download_model(model_name)
         if "tts_models" in model_name:
             local_download_dir = os.path.dirname(model_path)
             # download and run the model
             speaker_files = glob.glob(local_download_dir + "/speaker*")
+            language_files = glob.glob(local_download_dir + "/language*")
+            language_id = ""
             if len(speaker_files) > 0:
                 # multi-speaker model
                 if "speaker_ids" in speaker_files[0]:
                     speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
                 elif "speakers" in speaker_files[0]:
                     speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
+
+                # multi-lingual model - Assuming multi-lingual models are also multi-speaker
+                if len(language_files) > 0 and "language_ids" in language_files[0]:
+                    language_manager = LanguageManager(language_ids_file_path=language_files[0])
+                    language_id = language_manager.language_names[0]
+
                 speaker_id = list(speaker_manager.speaker_ids.keys())[0]
                 run_cli(
                     f"tts --model_name {model_name} "
-                    f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"'
+                    f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
                 )
             else:
                 # single-speaker model