diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index de7e439d..133346f6 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -37,8 +37,8 @@ In the worse case provide steps to reproduce the behaviour.
You can either run `TTS/bin/collect_env_info.py`
```bash
-wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_details.py
-python collect_env_details.py
+wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_info.py
+python collect_env_info.py
```
or fill in the fields below manually.
diff --git a/.github/workflows/aux_tests.yml b/.github/workflows/aux_tests.yml
index d5fe1bb3..59ba572d 100644
--- a/.github/workflows/aux_tests.yml
+++ b/.github/workflows/aux_tests.yml
@@ -22,25 +22,22 @@ jobs:
experimental: [false]
steps:
- uses: actions/checkout@v2
- - uses: actions/cache@v1
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
+ cache: 'pip'
+ cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
- sudo apt update
- sudo apt install -y git make
- sudo apt install -y python3-wheel gcc
+ sudo apt-get update
+ sudo apt-get install -y git make gcc
make system-deps
- - name: Upgrade pip
- run: python3 -m pip install --upgrade pip
+ - name: Install/upgrade Python setup deps
+ run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml
index d31e71cf..02faa7f6 100644
--- a/.github/workflows/pypi-release.yml
+++ b/.github/workflows/pypi-release.yml
@@ -7,7 +7,7 @@ defaults:
shell:
bash
jobs:
- build-package:
+ build-sdist:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
@@ -23,10 +23,63 @@ jobs:
with:
python-version: 3.8
- run: |
- python -m pip install -U pip setuptools twine toml
- python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["build-system"]["requires"]))' | pip install -r /dev/stdin
+ python -m pip install -U pip setuptools wheel build
- run: |
- python setup.py sdist
+ python -m build
+ - run: |
+ pip install dist/*.tar.gz
+ - uses: actions/upload-artifact@v2
+ with:
+ name: sdist
+ path: dist/*.tar.gz
+ build-wheels:
+ runs-on: ubuntu-20.04
+ strategy:
+ matrix:
+ python-version: ["3.6", "3.7", "3.8", "3.9"]
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - run: |
+ python -m pip install -U pip setuptools wheel build
+ - run: |
+ python -m build
+ - run: |
+ python -m pip install dist/*.whl
+ - uses: actions/upload-artifact@v2
+ with:
+ name: wheel-${{ matrix.python-version }}
+ path: dist/*.whl
+ publish-artifacts:
+ runs-on: ubuntu-20.04
+ needs: [build-sdist, build-wheels]
+ steps:
+ - run: |
+ mkdir dist
+ - uses: actions/download-artifact@v2
+ with:
+ name: "sdist"
+ path: "dist/"
+ - uses: actions/download-artifact@v2
+ with:
+ name: "wheel-3.6"
+ path: "dist/"
+ - uses: actions/download-artifact@v2
+ with:
+ name: "wheel-3.7"
+ path: "dist/"
+ - uses: actions/download-artifact@v2
+ with:
+ name: "wheel-3.8"
+ path: "dist/"
+ - uses: actions/download-artifact@v2
+ with:
+ name: "wheel-3.9"
+ path: "dist/"
+ - run: |
+ ls -lh dist/
- name: Setup PyPI config
run: |
cat << EOF > ~/.pypirc
@@ -34,5 +87,10 @@ jobs:
username=__token__
password=${{ secrets.PYPI_TOKEN }}
EOF
+ - uses: actions/setup-python@v2
+ with:
+ python-version: 3.8
- run: |
- twine upload --repository pypi dist/*.tar.gz
+ python -m pip install twine
+ - run: |
+ twine upload --repository pypi dist/*
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 4a30c26d..8d1e1af4 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -22,25 +22,22 @@ jobs:
experimental: [false]
steps:
- uses: actions/checkout@v2
- - uses: actions/cache@v1
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
+ cache: 'pip'
+ cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
- sudo apt update
- sudo apt install -y git make
- sudo apt install -y python3-wheel gcc
+ sudo apt-get update
+ sudo apt-get install -y git make gcc
make system-deps
- - name: Upgrade pip
- run: python3 -m pip install --upgrade pip
+ - name: Install/upgrade Python setup deps
+ run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
diff --git a/.github/workflows/tts_tests.yml b/.github/workflows/tts_tests.yml
index d05dca90..e352a117 100644
--- a/.github/workflows/tts_tests.yml
+++ b/.github/workflows/tts_tests.yml
@@ -22,25 +22,22 @@ jobs:
experimental: [false]
steps:
- uses: actions/checkout@v2
- - uses: actions/cache@v1
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
+ cache: 'pip'
+ cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
- sudo apt update
- sudo apt install -y git make
- sudo apt install -y python3-wheel gcc
+ sudo apt-get update
+ sudo apt-get install -y --no-install-recommends git make gcc
make system-deps
- - name: Upgrade pip
- run: python3 -m pip install --upgrade pip
+ - name: Install/upgrade Python setup deps
+ run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
diff --git a/.github/workflows/vocoder_tests.yml b/.github/workflows/vocoder_tests.yml
index 69e74dbf..24ae9e3f 100644
--- a/.github/workflows/vocoder_tests.yml
+++ b/.github/workflows/vocoder_tests.yml
@@ -22,25 +22,22 @@ jobs:
experimental: [false]
steps:
- uses: actions/checkout@v2
- - uses: actions/cache@v1
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
+ cache: 'pip'
+ cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
- sudo apt update
- sudo apt install -y git make
- sudo apt install -y python3-wheel gcc
+ sudo apt-get update
+ sudo apt-get install -y git make gcc
make system-deps
- - name: Upgrade pip
- run: python3 -m pip install --upgrade pip
+ - name: Install/upgrade Python setup deps
+ run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
diff --git a/.github/workflows/zoo_tests.yml b/.github/workflows/zoo_tests.yml
index 0dc4457b..f973dd0e 100644
--- a/.github/workflows/zoo_tests.yml
+++ b/.github/workflows/zoo_tests.yml
@@ -22,25 +22,22 @@ jobs:
experimental: [false]
steps:
- uses: actions/checkout@v2
- - uses: actions/cache@v1
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
+ cache: 'pip'
+ cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
- sudo apt update
- sudo apt install -y git make
- sudo apt install -y python3-wheel gcc
+ sudo apt-get update
+ sudo apt-get install -y git make gcc
make system-deps
- - name: Upgrade pip
- run: python3 -m pip install --upgrade pip
+ - name: Install/upgrade Python setup deps
+ run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
diff --git a/.gitignore b/.gitignore
index 64d1f0d5..7e9da0d8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -128,6 +128,8 @@ core
recipes/WIP/*
recipes/ljspeech/LJSpeech-1.1/*
recipes/vctk/VCTK/*
+recipes/**/*.npy
+recipes/**/*.json
VCTK-Corpus-removed-silence/*
# ignore training logs
@@ -161,4 +163,5 @@ speakers.json
internal/*
*_pitch.npy
*_phoneme.npy
-wandb
\ No newline at end of file
+wandb
+depot/*
\ No newline at end of file
diff --git a/README.md b/README.md
index fd9cd27c..4686ac67 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
#
🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research, was designed to achieve the best trade-off among ease-of-training, speed and quality.
-🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models), tools for measuring dataset quality and already used in **20+ languages** for products and research projects.
+🐸TTS comes with pretrained models, tools for measuring dataset quality, and is already used in **20+ languages** for products and research projects.
[](https://github.com/coqui-ai/TTS/actions)
[](https://badge.fury.io/py/TTS)
@@ -135,6 +135,66 @@ $ make install
If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](https://stackoverflow.com/questions/66726331/how-can-i-run-mozilla-tts-coqui-tts-training-with-cuda-on-a-windows-system).
+## Use TTS
+
+### Single Speaker Models
+
+- List provided models:
+
+ ```
+ $ tts --list_models
+ ```
+
+- Run TTS with default models:
+
+ ```
+ $ tts --text "Text for TTS"
+ ```
+
+- Run a TTS model with its default vocoder model:
+
+ ```
+ $ tts --text "Text for TTS" --model_name "//
+ ```
+
+- Run with specific TTS and vocoder models from the list:
+
+ ```
+ $ tts --text "Text for TTS" --model_name "//" --vocoder_name "//" --output_path
+ ```
+
+- Run your own TTS model (Using Griffin-Lim Vocoder):
+
+ ```
+ $ tts --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
+ ```
+
+- Run your own TTS and Vocoder models:
+ ```
+ $ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
+ --vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json
+ ```
+
+### Multi-speaker Models
+
+- List the available speakers and choose a <speaker_id> among them:
+
+ ```
+ $ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
+ ```
+
+- Run the multi-speaker TTS model with the target speaker ID:
+
+ ```
+ $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx
+ ```
+
+- Run your own multi-speaker TTS model:
+
+ ```
+ $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth.tar --speakers_file_path path/to/speaker.json --speaker_idx
+ ```
+
## Directory Structure
```
|- notebooks/ (Jupyter Notebooks for model evaluation, parameter selection and data analysis.)
diff --git a/TTS/.models.json b/TTS/.models.json
index 44c5fc6c..61a3257d 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -1,5 +1,17 @@
{
"tts_models": {
+ "multilingual":{
+ "multi-dataset":{
+ "your_tts":{
+ "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
+ "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
+ "default_vocoder": null,
+ "commit": "e9a1953e",
+ "license": "CC BY-NC-ND 4.0",
+ "contact": "egolge@coqui.ai"
+ }
+ }
+ },
"en": {
"ek1": {
"tacotron2": {
@@ -149,7 +161,7 @@
"commit": "bdab788d",
"license": "MIT",
"contact": "",
- "default_vocoder": null
+ "default_vocoder": "vocoder_models/uk/mai/multiband-melgan"
}
}
},
@@ -301,6 +313,17 @@
"commit": "3900448"
}
}
+ },
+ "uk": {
+ "mai": {
+ "multiband-melgan": {
+ "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/vocoder_models--uk--mai--multiband-melgan.zip",
+ "author":"@robinhad",
+ "commit": "bdab788d",
+ "license": "MIT",
+ "contact": ""
+ }
+ }
}
}
}
\ No newline at end of file
diff --git a/TTS/VERSION b/TTS/VERSION
index f7abe273..79a2734b 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.4.2
\ No newline at end of file
+0.5.0
\ No newline at end of file
diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index 0af98ff1..7b489fd6 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -12,7 +12,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
-from TTS.tts.utils.speakers import get_speaker_manager
+from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
@@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False):
enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=False,
verbose=verbose,
- speaker_id_mapping=speaker_manager.speaker_ids,
- d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
+ speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
+ d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None,
)
if c.use_phonemes and c.compute_input_seq_cache:
@@ -234,8 +234,13 @@ def main(args): # pylint: disable=redefined-outer-name
# use eval and training partitions
meta_data = meta_data_train + meta_data_eval
- # parse speakers
- speaker_manager = get_speaker_manager(c, args, meta_data_train)
+ # init speaker manager
+ if c.use_speaker_embedding:
+ speaker_manager = SpeakerManager(data_items=meta_data)
+ elif c.use_d_vector_file:
+ speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
+ else:
+ speaker_manager = None
# setup model
model = setup_model(c)
diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py
new file mode 100644
index 00000000..d3143ca3
--- /dev/null
+++ b/TTS/bin/find_unique_phonemes.py
@@ -0,0 +1,62 @@
+"""Find all the unique characters in a dataset"""
+import argparse
+import multiprocessing
+from argparse import RawTextHelpFormatter
+
+from tqdm.contrib.concurrent import process_map
+
+from TTS.config import load_config
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.utils.text import text2phone
+
+
+def compute_phonemes(item):
+ try:
+ text = item[0]
+ language = item[-1]
+ ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
+ except Exception:  # pylint: disable=broad-except
+ return []
+ return list(set(ph))
+
+
+def main():
+ # pylint: disable=W0601
+ global c
+ # pylint: disable=bad-option-value
+ parser = argparse.ArgumentParser(
+ description="""Find all the unique characters or phonemes in a dataset.\n\n"""
+ """
+ Example runs:
+
+ python TTS/bin/find_unique_phonemes.py --config_path config.json
+ """,
+ formatter_class=RawTextHelpFormatter,
+ )
+ parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
+ args = parser.parse_args()
+
+ c = load_config(args.config_path)
+
+ # load all datasets
+ train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
+ items = train_items + eval_items
+ print("Num items:", len(items))
+
+ phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
+ phones = []
+ for ph in phonemes:
+ phones.extend(ph)
+ phones = set(phones)
+ lower_phones = filter(lambda c: c.islower(), phones)
+ phones_force_lower = [c.lower() for c in phones]
+ phones_force_lower = set(phones_force_lower)
+
+ print(f" > Number of unique phonemes: {len(phones)}")
+ print(f" > Unique phonemes: {''.join(sorted(phones))}")
+ print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
+ print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py
new file mode 100755
index 00000000..9070f2da
--- /dev/null
+++ b/TTS/bin/remove_silence_using_vad.py
@@ -0,0 +1,89 @@
+import argparse
+import glob
+import multiprocessing
+import os
+import pathlib
+
+from tqdm.contrib.concurrent import process_map
+
+from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave
+
+
+def remove_silence(filepath):
+ output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
+ # ignore if the file exists
+ if os.path.exists(output_path) and not args.force:
+ return
+
+ # create all directory structure
+ pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+ # load wave
+ audio, sample_rate = read_wave(filepath)
+
+ # get speech segments
+ segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=args.aggressiveness)
+
+ segments = list(segments)
+ num_segments = len(segments)
+ flag = False
+ # create the output wave
+ if num_segments != 0:
+ for i, segment in reversed(list(enumerate(segments))):
+ if i >= 1:
+ if not flag:
+ concat_segment = segment
+ flag = True
+ else:
+ concat_segment = segment + concat_segment
+ else:
+ if flag:
+ segment = segment + concat_segment
+ # print("Saving: ", output_path)
+ write_wave(output_path, segment, sample_rate)
+ return
+ else:
+ print("> Just Copying the file to:", output_path)
+ # if fail to remove silence just write the file
+ write_wave(output_path, audio, sample_rate)
+ return
+
+
+def preprocess_audios():
+ files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
+ print("> Number of files: ", len(files))
+ if not args.force:
+ print("> Ignoring files that already exist in the output directory.")
+
+ if files:
+ # process the files in parallel, one worker per CPU core
+ num_threads = multiprocessing.cpu_count()
+ process_map(remove_silence, files, max_workers=num_threads, chunksize=15)
+ else:
+ print("> No files Found !")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2"
+ )
+ parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
+ parser.add_argument(
+ "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
+ )
+ parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
+ parser.add_argument(
+ "-g",
+ "--glob",
+ type=str,
+ default="**/*.wav",
+ help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
+ )
+ parser.add_argument(
+ "-a",
+ "--aggressiveness",
+ type=int,
+ default=2,
+ help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
+ )
+ args = parser.parse_args()
+ preprocess_audios()
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index fb2e41b4..509b3da6 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -23,72 +23,76 @@ def str2bool(v):
def main():
- # pylint: disable=bad-option-value
- parser = argparse.ArgumentParser(
- description="""Synthesize speech on command line.\n\n"""
- """You can either use your trained model or choose a model from the provided list.\n\n"""
- """If you don't specify any models, then it uses LJSpeech based English model.\n\n"""
- """
- # Example Runs:
+ description = """Synthesize speech on command line.
- ## Single Speaker Models
+You can either use your trained model or choose a model from the provided list.
- - list provided models
+If you don't specify any models, then it uses the LJSpeech-based English model.
+
+## Example Runs
+
+### Single Speaker Models
+
+- List provided models:
```
- $ ./TTS/bin/synthesize.py --list_models
+ $ tts --list_models
```
- - run tts with default models.
+- Run TTS with default models:
```
- $ ./TTS/bin synthesize.py --text "Text for TTS"
+ $ tts --text "Text for TTS"
```
- - run a tts model with its default vocoder model.
+- Run a TTS model with its default vocoder model:
```
- $ ./TTS/bin synthesize.py --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>
+ $ tts --text "Text for TTS" --model_name "//
```
- - run with specific tts and vocoder models from the list
+- Run with specific TTS and vocoder models from the list:
```
- $ ./TTS/bin/synthesize.py --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
+ $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
```
- - run your own TTS model (Using Griffin-Lim Vocoder)
+- Run your own TTS model (Using Griffin-Lim Vocoder):
```
- $ ./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
+ $ tts --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
```
- - run your own TTS and Vocoder models
+- Run your own TTS and Vocoder models:
```
- $ ./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
+ $ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
--vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json
```
- ## MULTI-SPEAKER MODELS
+### Multi-speaker Models
- - list the available speakers and choose as <speaker_id> among them.
+- List the available speakers and choose a <speaker_id> among them:
```
- $ ./TTS/bin/synthesize.py --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
+ $ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
```
- - run the multi-speaker TTS model with the target speaker ID.
+- Run the multi-speaker TTS model with the target speaker ID:
```
- $ ./TTS/bin/synthesize.py --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx
+ $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx
```
- - run your own multi-speaker TTS model.
+- Run your own multi-speaker TTS model:
```
- $ ./TTS/bin/synthesize.py --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth.tar --speakers_file_path path/to/speaker.json --speaker_idx
+ $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth.tar --speakers_file_path path/to/speaker.json --speaker_idx
```
- """,
+ """
+ # We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
+ # documentation in sync more easily.
+ parser = argparse.ArgumentParser(
+ description=description.replace(" ```\n", ""),
formatter_class=RawTextHelpFormatter,
)
@@ -98,7 +102,7 @@ def main():
nargs="?",
const=True,
default=False,
- help="list available pre-trained tts and vocoder models.",
+ help="list available pre-trained TTS and vocoder models.",
)
parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")
@@ -107,7 +111,7 @@ def main():
"--model_name",
type=str,
default="tts_models/en/ljspeech/tacotron2-DDC",
- help="Name of one of the pre-trained tts models in format //",
+ help="Name of one of the pre-trained TTS models in format //",
)
parser.add_argument(
"--vocoder_name",
@@ -148,12 +152,19 @@ def main():
# args for multi-speaker synthesis
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
+ parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
parser.add_argument(
"--speaker_idx",
type=str,
help="Target speaker ID for a multi-speaker TTS model.",
default=None,
)
+ parser.add_argument(
+ "--language_idx",
+ type=str,
+ help="Target language ID for a multi-lingual TTS model.",
+ default=None,
+ )
parser.add_argument(
"--speaker_wav",
nargs="+",
@@ -169,6 +180,14 @@ def main():
const=True,
default=False,
)
+ parser.add_argument(
+ "--list_language_idxs",
+ help="List available language ids for the defined multi-lingual model.",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ )
# aux args
parser.add_argument(
"--save_spectogram",
@@ -180,7 +199,7 @@ def main():
args = parser.parse_args()
# print the description if either text or list_models is not set
- if args.text is None and not args.list_models and not args.list_speaker_idxs:
+ if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs:
parser.parse_args(["-h"])
# load model manager
@@ -190,6 +209,7 @@ def main():
model_path = None
config_path = None
speakers_file_path = None
+ language_ids_file_path = None
vocoder_path = None
vocoder_config_path = None
encoder_path = None
@@ -213,6 +233,7 @@ def main():
model_path = args.model_path
config_path = args.config_path
speakers_file_path = args.speakers_file_path
+ language_ids_file_path = args.language_ids_file_path
if args.vocoder_path is not None:
vocoder_path = args.vocoder_path
@@ -227,6 +248,7 @@ def main():
model_path,
config_path,
speakers_file_path,
+ language_ids_file_path,
vocoder_path,
vocoder_config_path,
encoder_path,
@@ -242,6 +264,14 @@ def main():
print(synthesizer.tts_model.speaker_manager.speaker_ids)
return
+ # query language ids of a multi-lingual model.
+ if args.list_language_idxs:
+ print(
+ " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
+ )
+ print(synthesizer.tts_model.language_manager.language_id_mapping)
+ return
+
# check the arguments against a multi-speaker model.
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
print(
@@ -254,7 +284,7 @@ def main():
print(" > Text: {}".format(args.text))
# kick it
- wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style)
+ wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav)
# save the results
print(" > Saving output to {}".format(args.out_path))
diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index ad6d95f7..8c364300 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -11,7 +11,7 @@ from torch.utils.data import DataLoader
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
-from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
+from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
from TTS.speaker_encoder.utils.training import init_training
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
@@ -151,7 +151,7 @@ def main(args): # pylint: disable=redefined-outer-name
global meta_data_eval
ap = AudioProcessor(**c.audio)
- model = setup_model(c)
+ model = setup_speaker_encoder_model(c)
optimizer = RAdam(model.parameters(), lr=c.lr)
diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index e28e9dec..0f8c4760 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -1,9 +1,11 @@
import os
+import torch
-from TTS.config import load_config, register_config
+from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
+from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@@ -45,15 +47,39 @@ def main():
ap = AudioProcessor(**config.audio)
# init speaker manager
- if config.use_speaker_embedding:
+ if check_config_and_model_args(config, "use_speaker_embedding", True):
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
- elif config.use_d_vector_file:
- speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
+ if hasattr(config, "model_args"):
+ config.model_args.num_speakers = speaker_manager.num_speakers
+ else:
+ config.num_speakers = speaker_manager.num_speakers
+ elif check_config_and_model_args(config, "use_d_vector_file", True):
+ if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
+ speaker_manager = SpeakerManager(
+ d_vectors_file_path=config.model_args.d_vector_file,
+ encoder_model_path=config.model_args.speaker_encoder_model_path,
+ encoder_config_path=config.model_args.speaker_encoder_config_path,
+ use_cuda=torch.cuda.is_available(),
+ )
+ else:
+ speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
+ config.num_speakers = speaker_manager.num_speakers
+ if hasattr(config, "model_args"):
+ config.model_args.num_speakers = speaker_manager.num_speakers
else:
speaker_manager = None
+ if check_config_and_model_args(config, "use_language_embedding", True):
+ language_manager = LanguageManager(config=config)
+ if hasattr(config, "model_args"):
+ config.model_args.num_languages = language_manager.num_languages
+ else:
+ config.num_languages = language_manager.num_languages
+ else:
+ language_manager = None
+
# init the model from config
- model = setup_model(config, speaker_manager)
+ model = setup_model(config, speaker_manager, language_manager)
# init the trainer and 🚀
trainer = Trainer(
diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py
index f626163f..5c905295 100644
--- a/TTS/config/__init__.py
+++ b/TTS/config/__init__.py
@@ -95,3 +95,38 @@ def load_config(config_path: str) -> None:
config = config_class()
config.from_dict(config_dict)
return config
+
+
+def check_config_and_model_args(config, arg_name, value):
+ """Check the give argument in `config.model_args` if exist or in `config` for
+ the given value.
+
+ Return False if the argument does not exist in `config.model_args` or `config`.
+ This is to patch up the compatibility between models with and without `model_args`.
+
+ TODO: Remove this in the future with a unified approach.
+ """
+ if hasattr(config, "model_args"):
+ if arg_name in config.model_args:
+ return config.model_args[arg_name] == value
+ if hasattr(config, arg_name):
+ return config[arg_name] == value
+ return False
+
+
+def get_from_config_or_model_args(config, arg_name):
+ """Get the given argument from `config.model_args` if exist or in `config`."""
+ if hasattr(config, "model_args"):
+ if arg_name in config.model_args:
+ return config.model_args[arg_name]
+ return config[arg_name]
+
+
+def get_from_config_or_model_args_with_default(config, arg_name, def_val):
+ """Get the given argument from `config.model_args` if exist or in `config`."""
+ if hasattr(config, "model_args"):
+ if arg_name in config.model_args:
+ return config.model_args[arg_name]
+ if hasattr(config, arg_name):
+ return config[arg_name]
+ return def_val
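A minimal usage sketch for the helpers above, assuming a VITS-style config where the multi-speaker flags live under `model_args` (the concrete values below are illustrative, not part of this diff):

```python
# Hedged sketch: the helpers look in `config.model_args` first and fall back to the
# top-level `config` fields, smoothing over models with and without `model_args`.
from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_speaker_embedding = True  # illustrative value

if check_config_and_model_args(config, "use_speaker_embedding", True):
    # returns `model_args.num_speakers` here; otherwise the top-level field or the default
    num_speakers = get_from_config_or_model_args_with_default(config, "num_speakers", 0)
    print(f"multi-speaker run with {num_speakers} speakers")
```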
diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py
index d91bf2b6..f2bd40ad 100644
--- a/TTS/config/shared_configs.py
+++ b/TTS/config/shared_configs.py
@@ -60,6 +60,12 @@ class BaseAudioConfig(Coqpit):
trim_db (int):
Silence threshold used for silence trimming. Defaults to 45.
+ do_rms_norm (bool, optional):
+ Enable/disable RMS volume normalization when loading an audio file. Defaults to False.
+
+ db_level (float, optional):
+ dB level used for RMS normalization. The valid range is -99 to 0. Defaults to None.
+
power (float):
Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the
artifacts in the synthesized voice. Defaults to 1.5.
@@ -116,6 +122,9 @@ class BaseAudioConfig(Coqpit):
# silence trimming
do_trim_silence: bool = True
trim_db: int = 45
+ # rms volume normalization
+ do_rms_norm: bool = False
+ db_level: float = None
# griffin-lim params
power: float = 1.5
griffin_lim_iters: int = 60
@@ -184,9 +193,12 @@ class BaseDatasetConfig(Coqpit):
Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
Defaults to None.
- unused_speakers (List):
+ ignored_speakers (List):
List of speakers IDs that are not used at the training. Default None.
+ language (str):
+ Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to None.
+
meta_file_val (str):
Name of the dataset meta file that defines the instances used at validation.
@@ -198,7 +210,8 @@ class BaseDatasetConfig(Coqpit):
name: str = ""
path: str = ""
meta_file_train: str = ""
- ununsed_speakers: List[str] = None
+ ignored_speakers: List[str] = None
+ language: str = ""
meta_file_val: str = ""
meta_file_attn_mask: str = ""
@@ -335,6 +348,8 @@ class BaseTrainingConfig(Coqpit):
num_loader_workers: int = 0
num_eval_loader_workers: int = 0
use_noise_augment: bool = False
+ use_language_weighted_sampler: bool = False
+
# paths
output_path: str = None
# distributed
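A short, hedged example of the new RMS-normalization fields on `BaseAudioConfig`; the sample rate and target level are placeholders, not values prescribed by this change:

```python
from TTS.config.shared_configs import BaseAudioConfig

# Illustrative values only: enable RMS volume normalization applied when loading audio.
audio_config = BaseAudioConfig(
    sample_rate=16000,
    do_rms_norm=True,  # new flag added in this release
    db_level=-27.0,    # target dB level; valid range is -99 to 0
)
print(audio_config.do_rms_norm, audio_config.db_level)
```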
diff --git a/TTS/server/server.py b/TTS/server/server.py
index c6d67141..f2512582 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -100,7 +100,15 @@ if args.vocoder_path is not None:
# load models
synthesizer = Synthesizer(
- model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
+ tts_checkpoint=model_path,
+ tts_config_path=config_path,
+ tts_speakers_file=speakers_file_path,
+ tts_languages_file=None,
+ vocoder_checkpoint=vocoder_path,
+ vocoder_config=vocoder_config_path,
+ encoder_checkpoint="",
+ encoder_config="",
+ use_cuda=args.use_cuda,
)
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
@@ -165,7 +173,7 @@ def tts():
style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text))
- wavs = synthesizer.tts(text, speaker_idx=speaker_idx, style_wav=style_wav)
+ wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
return send_file(out, mimetype="audio/wav")
diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py
index 6b2b0dd4..5b0fee22 100644
--- a/TTS/speaker_encoder/dataset.py
+++ b/TTS/speaker_encoder/dataset.py
@@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset):
feats = torch.stack(feats)
labels = torch.stack(labels)
- return feats.transpose(1, 2), labels
+ return feats, labels
diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py
index de5bb007..ec394cdb 100644
--- a/TTS/speaker_encoder/models/lstm.py
+++ b/TTS/speaker_encoder/models/lstm.py
@@ -1,7 +1,9 @@
import numpy as np
import torch
+import torchaudio
from torch import nn
+from TTS.speaker_encoder.models.resnet import PreEmphasis
from TTS.utils.io import load_fsspec
@@ -33,9 +35,22 @@ class LSTMWithoutProjection(nn.Module):
class LSTMSpeakerEncoder(nn.Module):
- def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
+ def __init__(
+ self,
+ input_dim,
+ proj_dim=256,
+ lstm_dim=768,
+ num_lstm_layers=3,
+ use_lstm_with_projection=True,
+ use_torch_spec=False,
+ audio_config=None,
+ ):
super().__init__()
self.use_lstm_with_projection = use_lstm_with_projection
+ self.use_torch_spec = use_torch_spec
+ self.audio_config = audio_config
+ self.proj_dim = proj_dim
+
layers = []
# choise LSTM layer
if use_lstm_with_projection:
@@ -46,6 +61,38 @@ class LSTMSpeakerEncoder(nn.Module):
else:
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
+ self.instancenorm = nn.InstanceNorm1d(input_dim)
+
+ if self.use_torch_spec:
+ self.torch_spec = torch.nn.Sequential(
+ PreEmphasis(audio_config["preemphasis"]),
+ # TorchSTFT(
+ # n_fft=audio_config["fft_size"],
+ # hop_length=audio_config["hop_length"],
+ # win_length=audio_config["win_length"],
+ # sample_rate=audio_config["sample_rate"],
+ # window="hamming_window",
+ # mel_fmin=0.0,
+ # mel_fmax=None,
+ # use_htk=True,
+ # do_amp_to_db=False,
+ # n_mels=audio_config["num_mels"],
+ # power=2.0,
+ # use_mel=True,
+ # mel_norm=None,
+ # )
+ torchaudio.transforms.MelSpectrogram(
+ sample_rate=audio_config["sample_rate"],
+ n_fft=audio_config["fft_size"],
+ win_length=audio_config["win_length"],
+ hop_length=audio_config["hop_length"],
+ window_fn=torch.hamming_window,
+ n_mels=audio_config["num_mels"],
+ ),
+ )
+ else:
+ self.torch_spec = None
+
self._init_layers()
def _init_layers(self):
@@ -55,22 +102,33 @@ class LSTMSpeakerEncoder(nn.Module):
elif "weight" in name:
nn.init.xavier_normal_(param)
- def forward(self, x):
- # TODO: implement state passing for lstms
+ def forward(self, x, l2_norm=True):
+ """Forward pass of the model.
+
+ Args:
+ x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `use_torch_spec` must be `True`
+ to compute the spectrogram on-the-fly.
+ l2_norm (bool): Whether to L2-normalize the outputs.
+
+ Shapes:
+ - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
+ """
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.use_torch_spec:
+ x.squeeze_(1)
+ x = self.torch_spec(x)
+ x = self.instancenorm(x).transpose(1, 2)
d = self.layers(x)
if self.use_lstm_with_projection:
- d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
- else:
+ d = d[:, -1]
+ if l2_norm:
d = torch.nn.functional.normalize(d, p=2, dim=1)
return d
@torch.no_grad()
- def inference(self, x):
- d = self.layers.forward(x)
- if self.use_lstm_with_projection:
- d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
- else:
- d = torch.nn.functional.normalize(d, p=2, dim=1)
+ def inference(self, x, l2_norm=True):
+ d = self.forward(x, l2_norm=l2_norm)
return d
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
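A minimal sketch of the new waveform input path, assuming a 16 kHz mel setup; the `audio_config` values are assumptions for illustration and would normally come from the speaker encoder config:

```python
import torch

from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder

# Illustrative audio settings (assumed, not taken from this diff).
audio_config = {
    "preemphasis": 0.97,
    "fft_size": 512,
    "win_length": 400,
    "hop_length": 160,
    "sample_rate": 16000,
    "num_mels": 64,
}
encoder = LSTMSpeakerEncoder(input_dim=64, proj_dim=256, use_torch_spec=True, audio_config=audio_config)
encoder.eval()

wav = torch.randn(2, 1, 16000)  # (N, 1, T_in): one second of raw audio per item
with torch.no_grad():
    emb = encoder.forward(wav, l2_norm=True)  # (N, proj_dim) L2-normalized embeddings
print(emb.shape)
```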
diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index fcc850d7..d6c3dad4 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -1,10 +1,25 @@
import numpy as np
import torch
+import torchaudio
from torch import nn
+# from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec
+class PreEmphasis(nn.Module):
+ def __init__(self, coefficient=0.97):
+ super().__init__()
+ self.coefficient = coefficient
+ self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
+
+ def forward(self, x):
+ assert len(x.size()) == 2
+
+ x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
+ return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
+
+
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
@@ -70,12 +85,18 @@ class ResNetSpeakerEncoder(nn.Module):
num_filters=[32, 64, 128, 256],
encoder_type="ASP",
log_input=False,
+ use_torch_spec=False,
+ audio_config=None,
):
super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type
self.input_dim = input_dim
self.log_input = log_input
+ self.use_torch_spec = use_torch_spec
+ self.audio_config = audio_config
+ self.proj_dim = proj_dim
+
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.bn1 = nn.BatchNorm2d(num_filters[0])
@@ -88,6 +109,36 @@ class ResNetSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim)
+ if self.use_torch_spec:
+ self.torch_spec = torch.nn.Sequential(
+ PreEmphasis(audio_config["preemphasis"]),
+ # TorchSTFT(
+ # n_fft=audio_config["fft_size"],
+ # hop_length=audio_config["hop_length"],
+ # win_length=audio_config["win_length"],
+ # sample_rate=audio_config["sample_rate"],
+ # window="hamming_window",
+ # mel_fmin=0.0,
+ # mel_fmax=None,
+ # use_htk=True,
+ # do_amp_to_db=False,
+ # n_mels=audio_config["num_mels"],
+ # power=2.0,
+ # use_mel=True,
+ # mel_norm=None,
+ # )
+ torchaudio.transforms.MelSpectrogram(
+ sample_rate=audio_config["sample_rate"],
+ n_fft=audio_config["fft_size"],
+ win_length=audio_config["win_length"],
+ hop_length=audio_config["hop_length"],
+ window_fn=torch.hamming_window,
+ n_mels=audio_config["num_mels"],
+ ),
+ )
+ else:
+ self.torch_spec = None
+
outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential(
@@ -140,9 +191,23 @@ class ResNetSpeakerEncoder(nn.Module):
return out
def forward(self, x, l2_norm=False):
- x = x.transpose(1, 2)
+ """Forward pass of the model.
+
+ Args:
+ x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `use_torch_spec` must be `True`
+ to compute the spectrogram on-the-fly.
+ l2_norm (bool): Whether to L2-normalize the outputs.
+
+ Shapes:
+ - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
+ """
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
+ x.squeeze_(1)
+ # if `use_torch_spec` is set, compute the spectrogram here; otherwise use the mel spec computed by the AudioProcessor
+ if self.use_torch_spec:
+ x = self.torch_spec(x)
+
if self.log_input:
x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1)
@@ -175,11 +240,19 @@ class ResNetSpeakerEncoder(nn.Module):
return x
@torch.no_grad()
- def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
+ def inference(self, x, l2_norm=False):
+ return self.forward(x, l2_norm)
+
+ @torch.no_grad()
+ def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
"""
+ # map num_frames from spectrogram frames to the equivalent number of waveform samples
+ if self.use_torch_spec:
+ num_frames = num_frames * self.audio_config["hop_length"]
+
max_len = x.shape[1]
if max_len < num_frames:
@@ -195,11 +268,10 @@ class ResNetSpeakerEncoder(nn.Module):
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
- embeddings = self.forward(frames_batch, l2_norm=True)
+ embeddings = self.inference(frames_batch, l2_norm=l2_norm)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
-
return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py
index 1981fbe9..b8aa4093 100644
--- a/TTS/speaker_encoder/utils/generic_utils.py
+++ b/TTS/speaker_encoder/utils/generic_utils.py
@@ -170,16 +170,24 @@ def to_camel(text):
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
-def setup_model(c):
- if c.model_params["model_name"].lower() == "lstm":
+def setup_speaker_encoder_model(config: "Coqpit"):
+ if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
- c.model_params["input_dim"],
- c.model_params["proj_dim"],
- c.model_params["lstm_dim"],
- c.model_params["num_lstm_layers"],
+ config.model_params["input_dim"],
+ config.model_params["proj_dim"],
+ config.model_params["lstm_dim"],
+ config.model_params["num_lstm_layers"],
+ use_torch_spec=config.model_params.get("use_torch_spec", False),
+ audio_config=config.audio,
+ )
+ elif config.model_params["model_name"].lower() == "resnet":
+ model = ResNetSpeakerEncoder(
+ input_dim=config.model_params["input_dim"],
+ proj_dim=config.model_params["proj_dim"],
+ log_input=config.model_params.get("log_input", False),
+ use_torch_spec=config.model_params.get("use_torch_spec", False),
+ audio_config=config.audio,
)
- elif c.model_params["model_name"].lower() == "resnet":
- model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
return model
diff --git a/TTS/trainer.py b/TTS/trainer.py
index 2a2cfc46..7bffb386 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -202,7 +202,7 @@ class Trainer:
os.makedirs(output_path, exist_ok=True)
# copy training assets to the output folder
- copy_model_files(config, output_path, new_fields=None)
+ copy_model_files(config, output_path)
# init class members
self.args = args
@@ -439,7 +439,7 @@ class Trainer:
if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
print(" > Restoring Scaler...")
scaler = _restore_list_objs(checkpoint["scaler"], scaler)
- except (KeyError, RuntimeError):
+ except (KeyError, RuntimeError, ValueError):
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], config)
diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index d490e6e6..36c948af 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -82,8 +82,14 @@ class VitsConfig(BaseTTSConfig):
add_blank (bool):
If true, a blank token is added in between every character. Defaults to `True`.
- test_sentences (List[str]):
- List of sentences to be used for testing.
+ test_sentences (List[List]):
+ List of sentences with speaker and language information to be used for testing.
+
+ language_ids_file (str):
+ Path to the language ids file.
+
+ use_language_embedding (bool):
+ If true, language embedding is used. Defaults to `False`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
@@ -117,6 +123,7 @@ class VitsConfig(BaseTTSConfig):
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
+ speaker_encoder_loss_alpha: float = 1.0
# data loader params
return_wav: bool = True
@@ -130,13 +137,13 @@ class VitsConfig(BaseTTSConfig):
add_blank: bool = True
# testing
- test_sentences: List[str] = field(
+ test_sentences: List[List] = field(
default_factory=lambda: [
- "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
- "Be a voice, not an echo.",
- "I'm sorry Dave. I'm afraid I can't do that.",
- "This cake is great. It's so delicious and moist.",
- "Prior to November 22, 1963.",
+ ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
+ ["Be a voice, not an echo."],
+ ["I'm sorry Dave. I'm afraid I can't do that."],
+ ["This cake is great. It's so delicious and moist."],
+ ["Prior to November 22, 1963."],
]
)
@@ -146,29 +153,15 @@ class VitsConfig(BaseTTSConfig):
use_speaker_embedding: bool = False
speakers_file: str = None
speaker_embedding_channels: int = 256
+ language_ids_file: str = None
+ use_language_embedding: bool = False
# use d-vectors
use_d_vector_file: bool = False
- d_vector_file: str = False
+ d_vector_file: str = None
d_vector_dim: int = None
def __post_init__(self):
- # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
- if self.num_speakers > 0:
- self.model_args.num_speakers = self.num_speakers
-
- # speaker embedding settings
- if self.use_speaker_embedding:
- self.model_args.use_speaker_embedding = True
- if self.speakers_file:
- self.model_args.speakers_file = self.speakers_file
- if self.speaker_embedding_channels:
- self.model_args.speaker_embedding_channels = self.speaker_embedding_channels
-
- # d-vector settings
- if self.use_d_vector_file:
- self.model_args.use_d_vector_file = True
- if self.d_vector_dim is not None and self.d_vector_dim > 0:
- self.model_args.d_vector_dim = self.d_vector_dim
- if self.d_vector_file:
- self.model_args.d_vector_file = self.d_vector_file
+ for key, val in self.model_args.items():
+ if hasattr(self, key):
+ self[key] = val
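For reference, a hedged sketch of how a multi-lingual config might fill the new list-based `test_sentences`; the field order shown (text, speaker name, style wav, language) and the speaker name are assumptions for illustration, not spelled out in this diff:

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.test_sentences = [
    ["Be a voice, not an echo."],  # text only, as in the new defaults
    # Hypothetical multi-speaker / multi-lingual entry: [text, speaker_name, style_wav, language_name]
    ["Be a voice, not an echo.", "VCTK_p277", None, "en"],
]
```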
diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py
index 4fae974f..40eed7e3 100644
--- a/TTS/tts/datasets/__init__.py
+++ b/TTS/tts/datasets/__init__.py
@@ -67,16 +67,22 @@ def load_tts_samples(
root_path = dataset["path"]
meta_file_train = dataset["meta_file_train"]
meta_file_val = dataset["meta_file_val"]
+ ignored_speakers = dataset["ignored_speakers"]
+ language = dataset["language"]
+
# setup the right data processor
if formatter is None:
formatter = _get_formatter_by_name(name)
# load train set
- meta_data_train = formatter(root_path, meta_file_train)
+ meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
+ meta_data_train = [[*item, language] for item in meta_data_train]
+
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set
if eval_split:
if meta_file_val:
- meta_data_eval = formatter(root_path, meta_file_val)
+ meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
+ meta_data_eval = [[*item, language] for item in meta_data_eval]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval_all += meta_data_eval
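A hedged sketch of how the new per-dataset `language` field flows through `load_tts_samples`; the dataset path is a placeholder, and the sample layout reflects the `[*item, language]` appending above:

```python
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples

# Placeholder dataset definition (path and language code are illustrative).
dataset_config = BaseDatasetConfig(
    name="ljspeech",
    path="/data/LJSpeech-1.1/",
    meta_file_train="metadata.csv",
    language="en-us",
)
train_samples, eval_samples = load_tts_samples([dataset_config], eval_split=True)
# Each sample now carries the dataset language as its last element:
# [text, wav_file, speaker_name, "en-us"]
print(train_samples[0])
```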
diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 04314bab..2f20c865 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -37,6 +37,7 @@ class TTSDataset(Dataset):
enable_eos_bos: bool = False,
speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None,
+ language_id_mapping: Dict = None,
use_noise_augment: bool = False,
verbose: bool = False,
):
@@ -122,7 +123,9 @@ class TTSDataset(Dataset):
self.enable_eos_bos = enable_eos_bos
self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping
+ self.language_id_mapping = language_id_mapping
self.use_noise_augment = use_noise_augment
+
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
@@ -197,10 +200,10 @@ class TTSDataset(Dataset):
def load_data(self, idx):
item = self.items[idx]
- if len(item) == 4:
- text, wav_file, speaker_name, attn_file = item
+ if len(item) == 5:
+ text, wav_file, speaker_name, language_name, attn_file = item
else:
- text, wav_file, speaker_name = item
+ text, wav_file, speaker_name, language_name = item
attn = None
raw_text = text
@@ -218,7 +221,7 @@ class TTSDataset(Dataset):
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
- self.phoneme_language,
+ language_name if language_name else self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
@@ -260,6 +263,7 @@ class TTSDataset(Dataset):
"attn": attn,
"item_idx": self.items[idx][1],
"speaker_name": speaker_name,
+ "language_name": language_name,
"wav_file_name": os.path.basename(wav_file),
}
return sample
@@ -269,6 +273,9 @@ class TTSDataset(Dataset):
item = args[0]
func_args = args[1]
text, wav_file, *_ = item
+ func_args[3] = (
+ item[3] if item[3] else func_args[3]
+ ) # override phoneme language if specified by the dataset formatter
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
return phonemes
@@ -335,7 +342,6 @@ class TTSDataset(Dataset):
else:
lengths = np.array([len(ins[0]) for ins in self.items])
- # sort items based on the sequence length in ascending order
idxs = np.argsort(lengths)
new_items = []
ignored = []
@@ -345,10 +351,7 @@ class TTSDataset(Dataset):
ignored.append(idx)
else:
new_items.append(self.items[idx])
-
# shuffle batch groups
- # create batches with similar length items
- # the larger the `batch_group_size`, the higher the length variety in a batch.
if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size
@@ -356,14 +359,8 @@ class TTSDataset(Dataset):
temp_items = new_items[offset:end_offset]
random.shuffle(temp_items)
new_items[offset:end_offset] = temp_items
-
- if len(new_items) == 0:
- raise RuntimeError(" [!] No items left after filtering.")
-
- # update items to the new sorted items
self.items = new_items
- # logging
if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))
@@ -413,9 +410,14 @@ class TTSDataset(Dataset):
# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
+ # get language ids from language names
+ if self.language_id_mapping is not None:
+ language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
+ else:
+ language_ids = None
# get pre-computed d-vectors
if self.d_vector_mapping is not None:
- wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
+ wav_files_names = list(batch["wav_file_name"])
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
else:
d_vectors = None
@@ -466,6 +468,9 @@ class TTSDataset(Dataset):
if speaker_ids is not None:
speaker_ids = torch.LongTensor(speaker_ids)
+ if language_ids is not None:
+ language_ids = torch.LongTensor(language_ids)
+
# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
@@ -528,6 +533,7 @@ class TTSDataset(Dataset):
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
+ "language_ids": language_ids,
}
raise TypeError(
@@ -542,7 +548,6 @@ class TTSDataset(Dataset):
class PitchExtractor:
"""Pitch Extractor for computing F0 from wav files.
-
Args:
items (List[List]): Dataset samples.
verbose (bool): Whether to print the progress.
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index 425eb0cd..1f23f85e 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -12,7 +12,7 @@ from tqdm import tqdm
########################
-def tweb(root_path, meta_file):
+def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalize TWEB dataset.
https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset
"""
@@ -28,7 +28,7 @@ def tweb(root_path, meta_file):
return items
-def mozilla(root_path, meta_file):
+def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes Mozilla meta data files to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
@@ -43,7 +43,7 @@ def mozilla(root_path, meta_file):
return items
-def mozilla_de(root_path, meta_file):
+def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes Mozilla meta data files to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
@@ -59,7 +59,7 @@ def mozilla_de(root_path, meta_file):
return items
-def mailabs(root_path, meta_files=None):
+def mailabs(root_path, meta_files=None, ignored_speakers=None):
"""Normalizes M-AI-Labs meta data files to TTS format
Args:
@@ -68,25 +68,34 @@ def mailabs(root_path, meta_files=None):
recursively. Defaults to None
"""
speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/")
- if meta_files is None:
+ if not meta_files:
csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
else:
csv_files = meta_files
+
# meta_files = [f.strip() for f in meta_files.split(",")]
items = []
for csv_file in csv_files:
- txt_file = os.path.join(root_path, csv_file)
+ if os.path.isfile(csv_file):
+ txt_file = csv_file
+ else:
+ txt_file = os.path.join(root_path, csv_file)
+
folder = os.path.dirname(txt_file)
# determine speaker based on folder structure...
speaker_name_match = speaker_regex.search(txt_file)
if speaker_name_match is None:
continue
speaker_name = speaker_name_match.group("speaker_name")
+ # ignore speakers
+ if isinstance(ignored_speakers, list):
+ if speaker_name in ignored_speakers:
+ continue
print(" | > {}".format(csv_file))
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
- if meta_files is None:
+ if not meta_files:
wav_file = os.path.join(folder, "wavs", cols[0] + ".wav")
else:
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
@@ -94,11 +103,12 @@ def mailabs(root_path, meta_files=None):
text = cols[1].strip()
items.append([text, wav_file, speaker_name])
else:
- raise RuntimeError("> File %s does not exist!" % (wav_file))
+ # M-AI-Labs has some missing samples, so just print a warning
+ print("> File %s does not exist!" % (wav_file))
return items
-def ljspeech(root_path, meta_file):
+def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the LJSpeech meta data file to TTS format
https://keithito.com/LJ-Speech-Dataset/"""
txt_file = os.path.join(root_path, meta_file)
@@ -113,7 +123,7 @@ def ljspeech(root_path, meta_file):
return items
-def ljspeech_test(root_path, meta_file):
+def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the LJSpeech meta data file for TTS testing
https://keithito.com/LJ-Speech-Dataset/"""
txt_file = os.path.join(root_path, meta_file)
@@ -127,7 +137,7 @@ def ljspeech_test(root_path, meta_file):
return items
-def sam_accenture(root_path, meta_file):
+def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the sam-accenture meta data file to TTS format
https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
xml_file = os.path.join(root_path, "voice_over_recordings", meta_file)
@@ -144,12 +154,12 @@ def sam_accenture(root_path, meta_file):
return items
-def ruslan(root_path, meta_file):
+def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the RUSLAN meta data file to TTS format
https://ruslan-corpus.github.io/"""
txt_file = os.path.join(root_path, meta_file)
items = []
- speaker_name = "ljspeech"
+ speaker_name = "ruslan"
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
@@ -159,11 +169,11 @@ def ruslan(root_path, meta_file):
return items
-def css10(root_path, meta_file):
+def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the CSS10 dataset file to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
- speaker_name = "ljspeech"
+ speaker_name = "css10"
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
@@ -173,7 +183,7 @@ def css10(root_path, meta_file):
return items
-def nancy(root_path, meta_file):
+def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the Nancy meta data file to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
@@ -187,7 +197,7 @@ def nancy(root_path, meta_file):
return items
-def common_voice(root_path, meta_file):
+def common_voice(root_path, meta_file, ignored_speakers=None):
"""Normalize the common voice meta data file to TTS format."""
txt_file = os.path.join(root_path, meta_file)
items = []
@@ -198,15 +208,19 @@ def common_voice(root_path, meta_file):
cols = line.split("\t")
text = cols[2]
speaker_name = cols[0]
+ # ignore speakers
+ if isinstance(ignored_speakers, list):
+ if speaker_name in ignored_speakers:
+ continue
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
items.append([text, wav_file, "MCV_" + speaker_name])
return items
-def libri_tts(root_path, meta_files=None):
+def libri_tts(root_path, meta_files=None, ignored_speakers=None):
"""https://ai.google/tools/datasets/libri-tts/"""
items = []
- if meta_files is None:
+ if not meta_files:
meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
else:
if isinstance(meta_files, str):
@@ -222,13 +236,17 @@ def libri_tts(root_path, meta_files=None):
_root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
wav_file = os.path.join(_root_path, file_name + ".wav")
text = cols[2]
+ # ignore speakers
+ if isinstance(ignored_speakers, list):
+ if speaker_name in ignored_speakers:
+ continue
items.append([text, wav_file, "LTTS_" + speaker_name])
for item in items:
assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
return items
-def custom_turkish(root_path, meta_file):
+def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "turkish-female"
@@ -247,7 +265,7 @@ def custom_turkish(root_path, meta_file):
# ToDo: add the dataset link when the dataset is released publicly
-def brspeech(root_path, meta_file):
+def brspeech(root_path, meta_file, ignored_speakers=None):
"""BRSpeech 3.0 beta"""
txt_file = os.path.join(root_path, meta_file)
items = []
@@ -258,21 +276,25 @@ def brspeech(root_path, meta_file):
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
text = cols[2]
- speaker_name = cols[3]
- items.append([text, wav_file, speaker_name])
+ speaker_id = cols[3]
+ # ignore speakers
+ if isinstance(ignored_speakers, list):
+ if speaker_id in ignored_speakers:
+ continue
+ items.append([text, wav_file, speaker_id])
return items
-def vctk(root_path, meta_files=None, wavs_path="wav48"):
+def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
- test_speakers = meta_files
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
- if isinstance(test_speakers, list): # if is list ignore this speakers ids
- if speaker_id in test_speakers:
+ # ignore speakers
+ if isinstance(ignored_speakers, list):
+ if speaker_id in ignored_speakers:
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
@@ -282,15 +304,16 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
return items
-def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
+def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): # pylint: disable=unused-argument
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
items = []
txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for text_file in txt_files:
_, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
- if isinstance(meta_files, list): # if is list ignore this speakers ids
- if speaker_id in meta_files:
+ # ignore speakers
+ if isinstance(ignored_speakers, list):
+ if speaker_id in ignored_speakers:
continue
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([None, wav_file, "VCTK_" + speaker_id])
@@ -298,7 +321,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
return items
-def mls(root_path, meta_files=None):
+def mls(root_path, meta_files=None, ignored_speakers=None):
"""http://www.openslr.org/94/"""
items = []
with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
@@ -307,19 +330,23 @@ def mls(root_path, meta_files=None):
text = text[:-1]
speaker, book, *_ = file.split("_")
wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav")
+ # ignore speakers
+ if isinstance(ignored_speakers, list):
+ if speaker in ignored_speakers:
+ continue
items.append([text, wav_file, "MLS_" + speaker])
return items
# ======================================== VOX CELEB ===========================================
-def voxceleb2(root_path, meta_file=None):
+def voxceleb2(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument
"""
:param meta_file Used only for consistency with load_tts_samples api
"""
return _voxcel_x(root_path, meta_file, voxcel_idx="2")
-def voxceleb1(root_path, meta_file=None):
+def voxceleb1(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument
"""
:param meta_file Used only for consistency with load_tts_samples api
"""
@@ -361,7 +388,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
return [x.strip().split("|") for x in f.readlines()]
-def baker(root_path: str, meta_file: str) -> List[List[str]]:
+def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument
"""Normalizes the Baker meta data file to TTS format
Args:
@@ -381,7 +408,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
return items
-def kokoro(root_path, meta_file):
+def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset"""
txt_file = os.path.join(root_path, meta_file)
items = []
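
Note: the formatters above now share a common calling convention — `(root_path, meta_file, ..., ignored_speakers=None, **kwargs)` — so `load_tts_samples` can pass extra arguments uniformly. A minimal sketch of a custom formatter following that convention; the dataset name, file layout and column order below are hypothetical:

```python
import os

def my_dataset(root_path, meta_file, ignored_speakers=None, **kwargs):  # hypothetical formatter
    """Return [text, wav_file, speaker_name] items, skipping ignored speakers."""
    items = []
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
        for line in ttf:
            wav_name, text, speaker_name = line.strip().split("|")
            # honour `ignored_speakers` the same way the built-in formatters do
            if isinstance(ignored_speakers, list) and speaker_name in ignored_speakers:
                continue
            items.append([text, os.path.join(root_path, "wavs", wav_name + ".wav"), speaker_name])
    return items
```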
diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py
index 2c0303be..e766ed6a 100644
--- a/TTS/tts/layers/glow_tts/duration_predictor.py
+++ b/TTS/tts/layers/glow_tts/duration_predictor.py
@@ -18,8 +18,13 @@ class DurationPredictor(nn.Module):
dropout_p (float): Dropout rate used after each conv layer.
"""
- def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
+ def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
super().__init__()
+
+ # add language embedding dim in the input
+ if language_emb_dim:
+ in_channels += language_emb_dim
+
# class arguments
self.in_channels = in_channels
self.filter_channels = hidden_channels
@@ -36,7 +41,10 @@ class DurationPredictor(nn.Module):
if cond_channels is not None and cond_channels != 0:
self.cond = nn.Conv1d(cond_channels, in_channels, 1)
- def forward(self, x, x_mask, g=None):
+ if language_emb_dim != 0 and language_emb_dim is not None:
+ self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)
+
+ def forward(self, x, x_mask, g=None, lang_emb=None):
"""
Shapes:
- x: :math:`[B, C, T]`
@@ -45,6 +53,10 @@ class DurationPredictor(nn.Module):
"""
if g is not None:
x = x + self.cond(g)
+
+ if lang_emb is not None:
+ x = x + self.cond_lang(lang_emb)
+
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
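
Note: the duration predictor now takes an optional `language_emb_dim`/`lang_emb` pair; the language embedding is projected to the input width and added to `x`, just like the speaker conditioning. A rough shape check, with channel sizes picked only for illustration:

```python
import torch
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor

dp = DurationPredictor(in_channels=192, hidden_channels=256, kernel_size=3,
                       dropout_p=0.1, cond_channels=256, language_emb_dim=4)
x = torch.randn(2, 192 + 4, 50)   # text-encoder output already concatenated with the language embedding
x_mask = torch.ones(2, 1, 50)
g = torch.randn(2, 256, 1)        # speaker conditioning
lang_emb = torch.randn(2, 4, 1)   # language embedding, projected by `cond_lang` inside forward()
log_durations = dp(x, x_mask, g=g, lang_emb=lang_emb)  # -> [2, 1, 50]
```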
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 0ea342e8..7de45041 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -532,6 +532,7 @@ class VitsGeneratorLoss(nn.Module):
self.feat_loss_alpha = c.feat_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
+ self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
c.audio.hop_length,
@@ -585,6 +586,11 @@ class VitsGeneratorLoss(nn.Module):
l = kl / torch.sum(z_mask)
return l
+ @staticmethod
+ def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
+ l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
+ return l
+
def forward(
self,
waveform,
@@ -598,6 +604,9 @@ class VitsGeneratorLoss(nn.Module):
feats_disc_fake,
feats_disc_real,
loss_duration,
+ use_speaker_encoder_as_loss=False,
+ gt_spk_emb=None,
+ syn_spk_emb=None,
):
"""
Shapes:
@@ -618,13 +627,20 @@ class VitsGeneratorLoss(nn.Module):
# compute mel spectrograms from the waveforms
mel = self.stft(waveform)
mel_hat = self.stft(waveform_hat)
+
# compute losses
+ loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
- loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
+
+ if use_speaker_encoder_as_loss:
+ loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
+ loss += loss_se
+ return_dict["loss_spk_encoder"] = loss_se
+
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl
diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py
index cfc8b6ac..ef426ace 100644
--- a/TTS/tts/layers/vits/networks.py
+++ b/TTS/tts/layers/vits/networks.py
@@ -37,6 +37,7 @@ class TextEncoder(nn.Module):
num_layers: int,
kernel_size: int,
dropout_p: float,
+ language_emb_dim: int = None,
):
"""Text Encoder for VITS model.
@@ -55,8 +56,12 @@ class TextEncoder(nn.Module):
self.hidden_channels = hidden_channels
self.emb = nn.Embedding(n_vocab, hidden_channels)
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
+ if language_emb_dim:
+ hidden_channels += language_emb_dim
+
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
out_channels=hidden_channels,
@@ -72,13 +77,18 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
- def forward(self, x, x_lengths):
+ def forward(self, x, x_lengths, lang_emb=None):
"""
Shapes:
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
+
+ # concat the lang emb in embedding chars
+ if lang_emb is not None:
+ x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
+
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
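
Note: in the VITS text encoder the language embedding is concatenated to every character embedding before the transformer, so the transformer and projection layers run at `hidden_channels + language_emb_dim`. A shape-only sketch of that concatenation step; tensor sizes are illustrative:

```python
import torch

B, T, H, L = 2, 17, 192, 4
char_emb = torch.randn(B, T, H)   # output of nn.Embedding scaled by sqrt(H)
lang_emb = torch.randn(B, L, 1)   # [B, L, 1] language embedding
# expand the language embedding over the time axis and concatenate on the channel axis
x = torch.cat((char_emb, lang_emb.transpose(2, 1).expand(B, T, -1)), dim=-1)
assert x.shape == (B, T, H + L)   # downstream layers are built with H + L channels
```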
diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py
index 91e53da3..120d0944 100644
--- a/TTS/tts/layers/vits/stochastic_duration_predictor.py
+++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py
@@ -178,10 +178,21 @@ class StochasticDurationPredictor(nn.Module):
"""
def __init__(
- self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
+ self,
+ in_channels: int,
+ hidden_channels: int,
+ kernel_size: int,
+ dropout_p: float,
+ num_flows=4,
+ cond_channels=0,
+ language_emb_dim=0,
):
super().__init__()
+ # add language embedding dim in the input
+ if language_emb_dim:
+ in_channels += language_emb_dim
+
# condition encoder text
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
@@ -205,7 +216,10 @@ class StochasticDurationPredictor(nn.Module):
if cond_channels != 0 and cond_channels is not None:
self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
- def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
+ if language_emb_dim != 0 and language_emb_dim is not None:
+ self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1)
+
+ def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
@@ -217,6 +231,10 @@ class StochasticDurationPredictor(nn.Module):
x = self.pre(x)
if g is not None:
x = x + self.cond(g)
+
+ if lang_emb is not None:
+ x = x + self.cond_lang(lang_emb)
+
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py
index 780f22cd..4cc8b658 100644
--- a/TTS/tts/models/__init__.py
+++ b/TTS/tts/models/__init__.py
@@ -2,7 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module
-def setup_model(config, speaker_manager: "SpeakerManager" = None):
+def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
print(" > Using model: {}".format(config.model))
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
@@ -31,7 +31,10 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None):
config.model_params.num_chars = num_chars
if "model_args" in config:
config.model_args.num_chars = num_chars
- model = MyModel(config, speaker_manager=speaker_manager)
+ if config.model.lower() in ["vits"]: # If model supports multiple languages
+ model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
+ else:
+ model = MyModel(config, speaker_manager=speaker_manager)
return model
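
Note: `setup_model` now forwards a `LanguageManager` to models that accept one (currently only VITS). A hedged wiring sketch; the import paths follow this repo, but the config values and the hand-filled language mapping are placeholders for what a real run would load from the dataset config or `language_ids.json`:

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models import setup_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager

config = VitsConfig()                                # a real run also needs dataset/audio settings
config.model_args.use_language_embedding = True

language_manager = LanguageManager()
language_manager.language_id_mapping = {"en": 0, "pt-br": 1}   # placeholder mapping

model = setup_model(config, speaker_manager=SpeakerManager(), language_manager=language_manager)
```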
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 854526de..e52cd765 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -12,7 +12,8 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
+from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -73,9 +74,18 @@ class BaseTTS(BaseModel):
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
- def init_multispeaker(self, config: Coqpit):
- """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
- vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
+ def init_multispeaker(self, config: Coqpit, data: List = None):
+ """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
+ `in_channels` size of the connected layers.
+
+ This implementation yields 3 possible outcomes:
+
+ 1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are both False, do nothing.
+ 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
+ 3. If `config.use_speaker_embedding` is True, initialize a speaker embedding layer with channel size of
+ `config.d_vector_dim` or 512.
+
+ You can override this function for new models.
Args:
config (Coqpit): Model configuration.
@@ -97,6 +107,57 @@ class BaseTTS(BaseModel):
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
+ def get_aux_input(self, **kwargs) -> Dict:
+ """Prepare and return `aux_input` used by `forward()`"""
+ return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
+
+ def get_aux_input_from_test_setences(self, sentence_info):
+ if hasattr(self.config, "model_args"):
+ config = self.config.model_args
+ else:
+ config = self.config
+
+ # extract speaker and language info
+ text, speaker_name, style_wav, language_name = None, None, None, None
+
+ if isinstance(sentence_info, list):
+ if len(sentence_info) == 1:
+ text = sentence_info[0]
+ elif len(sentence_info) == 2:
+ text, speaker_name = sentence_info
+ elif len(sentence_info) == 3:
+ text, speaker_name, style_wav = sentence_info
+ elif len(sentence_info) == 4:
+ text, speaker_name, style_wav, language_name = sentence_info
+ else:
+ text = sentence_info
+
+ # get speaker id/d_vector
+ speaker_id, d_vector, language_id = None, None, None
+ if hasattr(self, "speaker_manager"):
+ if config.use_d_vector_file:
+ if speaker_name is None:
+ d_vector = self.speaker_manager.get_random_d_vector()
+ else:
+ d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
+ elif config.use_speaker_embedding:
+ if speaker_name is None:
+ speaker_id = self.speaker_manager.get_random_speaker_id()
+ else:
+ speaker_id = self.speaker_manager.speaker_ids[speaker_name]
+
+ # get language id
+ if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+ language_id = self.language_manager.language_id_mapping[language_name]
+
+ return {
+ "text": text,
+ "speaker_id": speaker_id,
+ "style_wav": style_wav,
+ "d_vector": d_vector,
+ "language_id": language_id,
+ }
+
def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `TTSDataset`.
@@ -122,6 +183,7 @@ class BaseTTS(BaseModel):
attn_mask = batch["attns"]
waveform = batch["waveform"]
pitch = batch["pitch"]
+ language_ids = batch["language_ids"]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
@@ -169,6 +231,7 @@ class BaseTTS(BaseModel):
"item_idx": item_idx,
"waveform": waveform,
"pitch": pitch,
+ "language_ids": language_ids,
}
def get_data_loader(
@@ -188,8 +251,15 @@ class BaseTTS(BaseModel):
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
- speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
- d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
+ if hasattr(config, "model_args"):
+ speaker_id_mapping = (
+ self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+ )
+ d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
+ config.use_d_vector_file = config.model_args.use_d_vector_file
+ else:
+ speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
+ d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
else:
speaker_id_mapping = None
d_vector_mapping = None
@@ -199,7 +269,14 @@ class BaseTTS(BaseModel):
if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config)
- # init dataset
+ if hasattr(self, "language_manager"):
+ language_id_mapping = (
+ self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+ )
+ else:
+ language_id_mapping = None
+
+ # init dataset
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
@@ -222,7 +299,8 @@ class BaseTTS(BaseModel):
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
- d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
+ d_vector_mapping=d_vector_mapping,
+ language_id_mapping=language_id_mapping,
)
# pre-compute phonemes
@@ -268,7 +346,22 @@ class BaseTTS(BaseModel):
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
- # init dataloader
+ # Weighted samplers
+ assert not (
+ num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
+ ), "language_weighted_sampler is not supported with DistributedSampler"
+ assert not (
+ num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
+ ), "speaker_weighted_sampler is not supported with DistributedSampler"
+
+ if sampler is None:
+ if getattr(config, "use_language_weighted_sampler", False):
+ print(" > Using Language weighted sampler")
+ sampler = get_language_weighted_sampler(dataset.items)
+ elif getattr(config, "use_speaker_weighted_sampler", False):
+ print(" > Using Language weighted sampler")
+ sampler = get_speaker_weighted_sampler(dataset.items)
+
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
@@ -340,8 +433,7 @@ class BaseTTS(BaseModel):
return test_figures, test_audios
def on_init_start(self, trainer):
- """Save the speaker.json at the beginning of the training. And update the config.json with the
- speakers.json file path."""
+ """Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path)
@@ -352,3 +444,13 @@ class BaseTTS(BaseModel):
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.json` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
+
+ if hasattr(self, "language_manager") and self.language_manager is not None:
+ output_path = os.path.join(trainer.output_path, "language_ids.json")
+ self.language_manager.save_language_ids_to_file(output_path)
+ trainer.config.language_ids_file = output_path
+ if hasattr(trainer.config, "model_args"):
+ trainer.config.model_args.language_ids_file = output_path
+ trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
+ print(f" > `language_ids.json` is saved to {output_path}.")
+ print(" > `language_ids_file` is updated in the config.json.")
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index bc459b7f..b2e4be9e 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1,13 +1,15 @@
import math
-import random
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple
import torch
+
+import torchaudio
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
+from torch.nn import functional as F
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
@@ -15,6 +17,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
+from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
@@ -138,11 +141,50 @@ class VitsArgs(Coqpit):
use_d_vector_file (bool):
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+ d_vector_file (str):
+ Path to the file including pre-computed speaker embeddings. Defaults to None.
+
d_vector_dim (int):
Number of d-vector channels. Defaults to 0.
detach_dp_input (bool):
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
+
+ use_language_embedding (bool):
+ Enable/Disable language embedding for multilingual models. Defaults to False.
+
+ embedded_language_dim (int):
+ Number of language embedding channels. Defaults to 4.
+
+ num_languages (int):
+ Number of languages for the language embedding layer. Defaults to 0.
+
+ language_ids_file (str):
+ Path to the language mapping file for the Language Manager. Defaults to None.
+
+ use_speaker_encoder_as_loss (bool):
+ Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
+
+ speaker_encoder_config_path (str):
+ Path to the speaker encoder config file used for SCL. Defaults to "".
+
+ speaker_encoder_model_path (str):
+ Path to the speaker encoder checkpoint file used for SCL. Defaults to "".
+
+ freeze_encoder (bool):
+ Freeze the text encoder weights during training. Defaults to False.
+
+ freeze_DP (bool):
+ Freeze the duration predictor weights during training. Defaults to False.
+
+ freeze_PE (bool):
+ Freeze the posterior encoder weights during training. Defaults to False.
+
+ freeze_flow_decoder (bool):
+ Freeze the flow decoder weights during training. Defaults to False.
+
+ freeze_waveform_decoder (bool):
+ Freeze the waveform decoder weights during training. Defaults to False.
"""
num_chars: int = 100
@@ -179,11 +221,23 @@ class VitsArgs(Coqpit):
use_speaker_embedding: bool = False
num_speakers: int = 0
speakers_file: str = None
+ d_vector_file: str = None
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
- d_vector_file: str = None
d_vector_dim: int = 0
detach_dp_input: bool = True
+ use_language_embedding: bool = False
+ embedded_language_dim: int = 4
+ num_languages: int = 0
+ language_ids_file: str = None
+ use_speaker_encoder_as_loss: bool = False
+ speaker_encoder_config_path: str = ""
+ speaker_encoder_model_path: str = ""
+ freeze_encoder: bool = False
+ freeze_DP: bool = False
+ freeze_PE: bool = False
+ freeze_flow_decoder: bool = False
+ freeze_waveform_decoder: bool = False
class Vits(BaseTTS):
@@ -216,13 +270,18 @@ class Vits(BaseTTS):
# pylint: disable=dangerous-default-value
- def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
+ def __init__(
+ self,
+ config: Coqpit,
+ speaker_manager: SpeakerManager = None,
+ language_manager: LanguageManager = None,
+ ):
super().__init__(config)
self.END2END = True
-
self.speaker_manager = speaker_manager
+ self.language_manager = language_manager
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
if "num_chars" not in config:
@@ -242,6 +301,7 @@ class Vits(BaseTTS):
self.args = args
self.init_multispeaker(config)
+ self.init_multilingual(config)
self.length_scale = args.length_scale
self.noise_scale = args.noise_scale
@@ -260,6 +320,7 @@ class Vits(BaseTTS):
args.num_layers_text_encoder,
args.kernel_size_text_encoder,
args.dropout_p_text_encoder,
+ language_emb_dim=self.embedded_language_dim,
)
self.posterior_encoder = PosteriorEncoder(
@@ -289,10 +350,16 @@ class Vits(BaseTTS):
args.dropout_p_duration_predictor,
4,
cond_channels=self.embedded_speaker_dim,
+ language_emb_dim=self.embedded_language_dim,
)
else:
self.duration_predictor = DurationPredictor(
- args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
+ args.hidden_channels,
+ 256,
+ 3,
+ args.dropout_p_duration_predictor,
+ cond_channels=self.embedded_speaker_dim,
+ language_emb_dim=self.embedded_language_dim,
)
self.waveform_decoder = HifiganGenerator(
@@ -318,54 +385,150 @@ class Vits(BaseTTS):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
+ You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
+
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.embedded_speaker_dim = 0
- if hasattr(config, "model_args"):
- config = config.model_args
+ self.num_speakers = self.args.num_speakers
- self.num_speakers = config.num_speakers
+ if self.speaker_manager:
+ self.num_speakers = self.speaker_manager.num_speakers
- if config.use_speaker_embedding:
- self._init_speaker_embedding(config)
+ if self.args.use_speaker_embedding:
+ self._init_speaker_embedding()
- if config.use_d_vector_file:
- self._init_d_vector(config)
+ if self.args.use_d_vector_file:
+ self._init_d_vector()
- def _init_speaker_embedding(self, config):
+ # TODO: make this a function
+ if self.args.use_speaker_encoder_as_loss:
+ if self.speaker_manager.speaker_encoder is None and (
+ not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
+ ):
+ raise RuntimeError(
+ " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
+ )
+
+ self.speaker_manager.speaker_encoder.eval()
+ print(" > External Speaker Encoder Loaded !!")
+
+ if (
+ hasattr(self.speaker_manager.speaker_encoder, "audio_config")
+ and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
+ ):
+ self.audio_transform = torchaudio.transforms.Resample(
+ orig_freq=self.audio_config["sample_rate"],
+ new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
+ )
+ else:
+ self.audio_transform = None
+
+ def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
- if config.speakers_file is not None:
- self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
-
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
- self.embedded_speaker_dim = config.speaker_embedding_channels
+ self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
- def _init_d_vector(self, config):
+ def _init_d_vector(self):
# pylint: disable=attribute-defined-outside-init
if hasattr(self, "emb_g"):
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
- self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
- self.embedded_speaker_dim = config.d_vector_dim
+ self.embedded_speaker_dim = self.args.d_vector_dim
+
+ def init_multilingual(self, config: Coqpit):
+ """Initialize multilingual modules of a model.
+
+ Args:
+ config (Coqpit): Model configuration.
+ """
+ if self.args.language_ids_file is not None:
+ self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
+
+ if self.args.use_language_embedding and self.language_manager:
+ print(" > initialization of language-embedding layers.")
+ self.num_languages = self.language_manager.num_languages
+ self.embedded_language_dim = self.args.embedded_language_dim
+ self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
+ torch.nn.init.xavier_uniform_(self.emb_l.weight)
+ else:
+ self.embedded_language_dim = 0
+ self.emb_l = None
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
- sid, g = None, None
+ sid, g, lid = None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
sid = sid.unsqueeze_(0)
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
- g = aux_input["d_vectors"]
- return sid, g
+ g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
+ if g.ndim == 2:
+ g = g.unsqueeze_(0)
+
+ if "language_ids" in aux_input and aux_input["language_ids"] is not None:
+ lid = aux_input["language_ids"]
+ if lid.ndim == 0:
+ lid = lid.unsqueeze_(0)
+
+ return sid, g, lid
def get_aux_input(self, aux_input: Dict):
- sid, g = self._set_cond_input(aux_input)
- return {"speaker_id": sid, "style_wav": None, "d_vector": g}
+ sid, g, lid = self._set_cond_input(aux_input)
+ return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
+
+ def get_aux_input_from_test_sentences(self, sentence_info):
+ if hasattr(self.config, "model_args"):
+ config = self.config.model_args
+ else:
+ config = self.config
+
+ # extract speaker and language info
+ text, speaker_name, style_wav, language_name = None, None, None, None
+
+ if isinstance(sentence_info, list):
+ if len(sentence_info) == 1:
+ text = sentence_info[0]
+ elif len(sentence_info) == 2:
+ text, speaker_name = sentence_info
+ elif len(sentence_info) == 3:
+ text, speaker_name, style_wav = sentence_info
+ elif len(sentence_info) == 4:
+ text, speaker_name, style_wav, language_name = sentence_info
+ else:
+ text = sentence_info
+
+ # get speaker id/d_vector
+ speaker_id, d_vector, language_id = None, None, None
+ if hasattr(self, "speaker_manager"):
+ if config.use_d_vector_file:
+ if speaker_name is None:
+ d_vector = self.speaker_manager.get_random_d_vector()
+ else:
+ d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False)
+ elif config.use_speaker_embedding:
+ if speaker_name is None:
+ speaker_id = self.speaker_manager.get_random_speaker_id()
+ else:
+ speaker_id = self.speaker_manager.speaker_ids[speaker_name]
+
+ # get language id
+ if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+ language_id = self.language_manager.language_id_mapping[language_name]
+
+ return {
+ "text": text,
+ "speaker_id": speaker_id,
+ "style_wav": style_wav,
+ "d_vector": d_vector,
+ "language_id": language_id,
+ "language_name": language_name,
+ }
def forward(
self,
@@ -373,7 +536,8 @@ class Vits(BaseTTS):
x_lengths: torch.tensor,
y: torch.tensor,
y_lengths: torch.tensor,
- aux_input={"d_vectors": None, "speaker_ids": None},
+ waveform: torch.tensor,
+ aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
) -> Dict:
"""Forward pass of the model.
@@ -382,7 +546,9 @@ class Vits(BaseTTS):
x_lengths (torch.tensor): Batch of input character sequence lengths.
y (torch.tensor): Batch of input spectrograms.
y_lengths (torch.tensor): Batch of input spectrogram lengths.
- aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
+ waveform (torch.tensor): Batch of ground truth waveforms per sample.
+ aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
+ Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.
Returns:
Dict: model outputs keyed by the output name.
@@ -392,17 +558,24 @@ class Vits(BaseTTS):
- x_lengths: :math:`[B]`
- y: :math:`[B, C, T_spec]`
- y_lengths: :math:`[B]`
+ - waveform: :math:`[B, T_wav, 1]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
+ - language_ids: :math:`[B]`
"""
outputs = {}
- sid, g = self._set_cond_input(aux_input)
- x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
-
+ sid, g, lid = self._set_cond_input(aux_input)
# speaker embedding
- if self.num_speakers > 1 and sid is not None:
+ if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
+ # language embedding
+ lang_emb = None
+ if self.args.use_language_embedding and lid is not None:
+ lang_emb = self.emb_l(lid).unsqueeze(-1)
+
+ x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+
# posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -428,6 +601,7 @@ class Vits(BaseTTS):
x_mask,
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
+ lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = loss_duration / torch.sum(x_mask)
else:
@@ -436,6 +610,7 @@ class Vits(BaseTTS):
x.detach() if self.args.detach_dp_input else x,
x_mask,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
+ lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration
@@ -447,40 +622,73 @@ class Vits(BaseTTS):
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
o = self.waveform_decoder(z_slice, g=g)
+
+ wav_seg = segment(
+ waveform,
+ slice_ids * self.config.audio.hop_length,
+ self.args.spec_segment_size * self.config.audio.hop_length,
+ )
+
+ if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
+ # concatenate generated and ground-truth waveforms
+ wavs_batch = torch.cat((wav_seg, o), dim=0)
+
+ # resample audio to speaker encoder sample_rate
+ # pylint: disable=W0105
+ if self.audio_transform is not None:
+ wavs_batch = self.audio_transform(wavs_batch)
+
+ pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
+
+ # split generated and GT speaker embeddings
+ gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
+ else:
+ gt_spk_emb, syn_spk_emb = None, None
+
outputs.update(
{
"model_outputs": o,
"alignments": attn.squeeze(1),
- "slice_ids": slice_ids,
"z": z,
"z_p": z_p,
"m_p": m_p,
"logs_p": logs_p,
"m_q": m_q,
"logs_q": logs_q,
+ "waveform_seg": wav_seg,
+ "gt_spk_emb": gt_spk_emb,
+ "syn_spk_emb": syn_spk_emb,
}
)
return outputs
- def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
+ def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
"""
Shapes:
- x: :math:`[B, T_seq]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
"""
- sid, g = self._set_cond_input(aux_input)
+ sid, g, lid = self._set_cond_input(aux_input)
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
- x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
-
- if self.num_speakers > 0 and sid is not None:
+ # speaker embedding
+ if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1)
+ # language embedding
+ lang_emb = None
+ if self.args.use_language_embedding and lid is not None:
+ lang_emb = self.emb_l(lid).unsqueeze(-1)
+
+ x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+
if self.args.use_sdp:
- logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
+ logw = self.duration_predictor(
+ x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+ )
else:
- logw = self.duration_predictor(x, x_mask, g=g)
+ logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
@@ -499,12 +707,30 @@ class Vits(BaseTTS):
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
return outputs
- def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
- """TODO: create an end-point for voice conversion"""
+ def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
+ """Forward pass for voice conversion
+
+ TODO: create an end-point for voice conversion
+
+ Args:
+ y (Tensor): Reference spectrograms. Tensor of shape [B, T, C]
+ y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
+ speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
+ speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
+ """
assert self.num_speakers > 0, "num_speakers have to be larger than 0."
- g_src = self.emb_g(sid_src).unsqueeze(-1)
- g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
- z, _, _, y_mask = self.enc_q(y, y_lengths, g=g_src)
+
+ # speaker embedding
+ if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
+ g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
+ g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
+ elif self.args.use_speaker_embedding and self.args.use_d_vector_file:
+ g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
+ g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
+ else:
+ raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
+
+ z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
@@ -525,6 +751,30 @@ class Vits(BaseTTS):
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
+ if self.args.freeze_encoder:
+ for param in self.text_encoder.parameters():
+ param.requires_grad = False
+
+ if hasattr(self, "emb_l"):
+ for param in self.emb_l.parameters():
+ param.requires_grad = False
+
+ if self.args.freeze_PE:
+ for param in self.posterior_encoder.parameters():
+ param.requires_grad = False
+
+ if self.args.freeze_DP:
+ for param in self.duration_predictor.parameters():
+ param.requires_grad = False
+
+ if self.args.freeze_flow_decoder:
+ for param in self.flow.parameters():
+ param.requires_grad = False
+
+ if self.args.freeze_waveform_decoder:
+ for param in self.waveform_decoder.parameters():
+ param.requires_grad = False
+
if optimizer_idx == 0:
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
@@ -532,6 +782,7 @@ class Vits(BaseTTS):
linear_input = batch["linear_input"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
+ language_ids = batch["language_ids"]
waveform = batch["waveform"]
# generator pass
@@ -540,31 +791,26 @@ class Vits(BaseTTS):
text_lengths,
linear_input.transpose(1, 2),
mel_lengths,
- aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
+ waveform.transpose(1, 2),
+ aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
)
# cache tensors for the discriminator
self.y_disc_cache = None
self.wav_seg_disc_cache = None
self.y_disc_cache = outputs["model_outputs"]
- wav_seg = segment(
- waveform.transpose(1, 2),
- outputs["slice_ids"] * self.config.audio.hop_length,
- self.args.spec_segment_size * self.config.audio.hop_length,
- )
- self.wav_seg_disc_cache = wav_seg
- outputs["waveform_seg"] = wav_seg
+ self.wav_seg_disc_cache = outputs["waveform_seg"]
# compute discriminator scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
- outputs["model_outputs"], wav_seg
+ outputs["model_outputs"], outputs["waveform_seg"]
)
# compute losses
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
waveform_hat=outputs["model_outputs"].float(),
- waveform=wav_seg.float(),
+ waveform=outputs["waveform_seg"].float(),
z_p=outputs["z_p"].float(),
logs_q=outputs["logs_q"].float(),
m_p=outputs["m_p"].float(),
@@ -574,6 +820,9 @@ class Vits(BaseTTS):
feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"],
loss_duration=outputs["loss_duration"],
+ use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
+ gt_spk_emb=outputs["gt_spk_emb"],
+ syn_spk_emb=outputs["syn_spk_emb"],
)
elif optimizer_idx == 1:
@@ -651,32 +900,28 @@ class Vits(BaseTTS):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
- aux_inputs = {
- "speaker_id": None
- if not self.config.use_speaker_embedding
- else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
- "d_vector": None
- if not self.config.use_d_vector_file
- else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
- "style_wav": None,
- }
- for idx, sen in enumerate(test_sentences):
- wav, alignment, _, _ = synthesis(
- self,
- sen,
- self.config,
- "cuda" in str(next(self.parameters()).device),
- ap,
- speaker_id=aux_inputs["speaker_id"],
- d_vector=aux_inputs["d_vector"],
- style_wav=aux_inputs["style_wav"],
- enable_eos_bos_chars=self.config.enable_eos_bos_chars,
- use_griffin_lim=True,
- do_trim_silence=False,
- ).values()
-
- test_audios["{}-audio".format(idx)] = wav
- test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
+ for idx, s_info in enumerate(test_sentences):
+ try:
+ aux_inputs = self.get_aux_input_from_test_sentences(s_info)
+ wav, alignment, _, _ = synthesis(
+ self,
+ aux_inputs["text"],
+ self.config,
+ "cuda" in str(next(self.parameters()).device),
+ ap,
+ speaker_id=aux_inputs["speaker_id"],
+ d_vector=aux_inputs["d_vector"],
+ style_wav=aux_inputs["style_wav"],
+ language_id=aux_inputs["language_id"],
+ language_name=aux_inputs["language_name"],
+ enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+ use_griffin_lim=True,
+ do_trim_silence=False,
+ ).values()
+ test_audios["{}-audio".format(idx)] = wav
+ test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
+ except: # pylint: disable=bare-except
+ print(" !! Error creating Test Sentence -", idx)
return test_figures, test_audios
def get_optimizer(self) -> List:
@@ -695,8 +940,12 @@ class Vits(BaseTTS):
self.waveform_decoder.parameters(),
)
# add the speaker embedding layer
- if hasattr(self, "emb_g"):
+ if hasattr(self, "emb_g") and self.args.use_speaker_embedding and not self.args.use_d_vector_file:
gen_parameters = chain(gen_parameters, self.emb_g.parameters())
+ # add the language embedding layer
+ if hasattr(self, "emb_l") and self.args.use_language_embedding:
+ gen_parameters = chain(gen_parameters, self.emb_l.parameters())
+
optimizer0 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
)
@@ -769,6 +1018,10 @@ class Vits(BaseTTS):
): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ # compatibility band-aid: drop speaker encoder weights baked into older pre-trained checkpoints
+ # TODO: consider baking the speaker encoder into the model and calling it from there,
+ # as that is probably easier for model distribution.
+ state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
self.load_state_dict(state["model"])
if eval:
self.eval()
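
Note: the new `VitsArgs` fields are plain Coqpit fields, so SCL and the freeze switches can be toggled from the config (under `model_args` in a `VitsConfig`). A hedged sketch with illustrative values; the checkpoint paths are placeholders:

```python
from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_speaker_embedding=True,
    use_language_embedding=True,
    embedded_language_dim=4,
    use_speaker_encoder_as_loss=True,                              # enables the speaker consistency loss (SCL)
    speaker_encoder_model_path="/path/to/se_checkpoint.pth.tar",   # placeholder path
    speaker_encoder_config_path="/path/to/se_config.json",         # placeholder path
    freeze_DP=True,                                                # freeze the duration predictor while fine-tuning
)
```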
diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py
new file mode 100644
index 00000000..fc7eec57
--- /dev/null
+++ b/TTS/tts/utils/languages.py
@@ -0,0 +1,122 @@
+import json
+import os
+from typing import Dict, List
+
+import fsspec
+import numpy as np
+import torch
+from coqpit import Coqpit
+from torch.utils.data.sampler import WeightedRandomSampler
+
+
+class LanguageManager:
+ """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
+ in a way that can be queried by language.
+
+ Args:
+ language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
+ TTS models. Defaults to "".
+ config (Coqpit, optional): Coqpit config that contains the language information in the `datasets` field.
+ Defaults to None.
+
+ Examples:
+ >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
+ >>> language_id_mapper = manager.language_id_mapping
+ """
+
+ language_id_mapping: Dict = {}
+
+ def __init__(
+ self,
+ language_ids_file_path: str = "",
+ config: Coqpit = None,
+ ):
+ self.language_id_mapping = {}
+ if language_ids_file_path:
+ self.set_language_ids_from_file(language_ids_file_path)
+
+ if config:
+ self.set_language_ids_from_config(config)
+
+ @staticmethod
+ def _load_json(json_file_path: str) -> Dict:
+ with fsspec.open(json_file_path, "r") as f:
+ return json.load(f)
+
+ @staticmethod
+ def _save_json(json_file_path: str, data: dict) -> None:
+ with fsspec.open(json_file_path, "w") as f:
+ json.dump(data, f, indent=4)
+
+ @property
+ def num_languages(self) -> int:
+ return len(list(self.language_id_mapping.keys()))
+
+ @property
+ def language_names(self) -> List:
+ return list(self.language_id_mapping.keys())
+
+ @staticmethod
+ def parse_language_ids_from_config(c: Coqpit) -> Dict:
+ """Set language id from config.
+
+ Args:
+ c (Coqpit): Config
+
+ Returns:
+ Dict: Mapping of language names to IDs.
+ """
+ languages = set({})
+ for dataset in c.datasets:
+ if "language" in dataset:
+ languages.add(dataset["language"])
+ else:
+ raise ValueError(f"Dataset {dataset['name']} has no language specified.")
+ return {name: i for i, name in enumerate(sorted(list(languages)))}
+
+ def set_language_ids_from_config(self, c: Coqpit) -> None:
+ """Set language IDs from config samples.
+
+ Args:
+ items (List): Data sampled returned by `load_meta_data()`.
+ """
+ self.language_id_mapping = self.parse_language_ids_from_config(c)
+
+ def set_language_ids_from_file(self, file_path: str) -> None:
+ """Load language ids from a json file.
+
+ Args:
+ file_path (str): Path to the target json file.
+ """
+ self.language_id_mapping = self._load_json(file_path)
+
+ def save_language_ids_to_file(self, file_path: str) -> None:
+ """Save language IDs to a json file.
+
+ Args:
+ file_path (str): Path to the output file.
+ """
+ self._save_json(file_path, self.language_id_mapping)
+
+
+def _set_file_path(path):
+ """Find the language_ids.json under the given path or the above it.
+ Intended to band aid the different paths returned in restored and continued training."""
+ path_restore = os.path.join(os.path.dirname(path), "language_ids.json")
+ path_continue = os.path.join(path, "language_ids.json")
+ fs = fsspec.get_mapper(path).fs
+ if fs.exists(path_restore):
+ return path_restore
+ if fs.exists(path_continue):
+ return path_continue
+ return None
+
+
+def get_language_weighted_sampler(items: list):
+ language_names = np.array([item[3] for item in items])
+ unique_language_names = np.unique(language_names).tolist()
+ language_ids = [unique_language_names.index(l) for l in language_names]
+ language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
+ weight_language = 1.0 / language_count
+ dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
+ return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
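
Note: `get_language_weighted_sampler` weights every sample by the inverse frequency of its language (column 3 of each item), so rare languages are drawn as often as common ones. A toy sketch of the wiring; the item list and the stand-in dataset are made up:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from TTS.tts.utils.languages import get_language_weighted_sampler

# toy items: 3 English samples, 1 Portuguese sample -> per-sample weights 1/3 and 1/1
items = [["hi", "a.wav", "spk1", "en"],
         ["hey", "b.wav", "spk1", "en"],
         ["yo", "c.wav", "spk2", "en"],
         ["oi", "d.wav", "spk3", "pt-br"]]
sampler = get_language_weighted_sampler(items)
dataset = TensorDataset(torch.arange(len(items)))   # stand-in dataset, just to show the wiring
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
```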
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 13696a20..07076d90 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -7,9 +7,10 @@ import fsspec
import numpy as np
import torch
from coqpit import Coqpit
+from torch.utils.data.sampler import WeightedRandomSampler
from TTS.config import load_config
-from TTS.speaker_encoder.utils.generic_utils import setup_model
+from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.utils.audio import AudioProcessor
@@ -161,8 +162,10 @@ class SpeakerManager:
file_path (str): Path to the target json file.
"""
self.d_vectors = self._load_json(file_path)
+
speakers = sorted({x["name"] for x in self.d_vectors.values()})
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
+
self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))
def get_d_vector_by_clip(self, clip_idx: str) -> List:
@@ -209,6 +212,32 @@ class SpeakerManager:
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
return d_vectors
+ def get_random_speaker_id(self) -> Any:
+ """Get a random d_vector.
+
+ Args:
+
+ Returns:
+ np.ndarray: d_vector.
+ """
+ if self.speaker_ids:
+ return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]
+
+ return None
+
+ def get_random_d_vector(self) -> Any:
+ """Get a random D ID.
+
+ Args:
+
+ Returns:
+ np.ndarray: d_vector.
+ """
+ if self.d_vectors:
+ return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]
+
+ return None
+
def get_speakers(self) -> List:
return self.speaker_ids
@@ -223,18 +252,15 @@ class SpeakerManager:
config_path (str): Model config file path.
"""
self.speaker_encoder_config = load_config(config_path)
- self.speaker_encoder = setup_model(self.speaker_encoder_config)
+ self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
- # normalize the input audio level and trim silences
- # self.speaker_encoder_ap.do_sound_norm = True
- # self.speaker_encoder_ap.do_trim_silence = True
- def compute_d_vector_from_clip(self, wav_file: Union[str, list]) -> list:
+ def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
"""Compute a d_vector from a given audio file.
Args:
- wav_file (Union[str, list]): Target file path.
+ wav_file (Union[str, List[str]]): Target file path.
Returns:
list: Computed d_vector.
@@ -242,12 +268,16 @@ class SpeakerManager:
def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
- spec = self.speaker_encoder_ap.melspectrogram(waveform)
- spec = torch.from_numpy(spec.T)
+ if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
+ m_input = self.speaker_encoder_ap.melspectrogram(waveform)
+ m_input = torch.from_numpy(m_input)
+ else:
+ m_input = torch.from_numpy(waveform)
+
if self.use_cuda:
- spec = spec.cuda()
- spec = spec.unsqueeze(0)
- d_vector = self.speaker_encoder.compute_embedding(spec)
+ m_input = m_input.cuda()
+ m_input = m_input.unsqueeze(0)
+ d_vector = self.speaker_encoder.compute_embedding(m_input)
return d_vector
if isinstance(wav_file, list):
@@ -364,11 +394,14 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
# new speaker manager with speaker IDs file.
speaker_manager.set_speaker_ids_from_file(c.speakers_file)
- print(
- " > Speaker manager is loaded with {} speakers: {}".format(
- speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
+
+ if speaker_manager.num_speakers > 0:
+ print(
+ " > Speaker manager is loaded with {} speakers: {}".format(
+ speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
+ )
)
- )
+
# save file if path is defined
if out_path:
out_file_path = os.path.join(out_path, "speakers.json")
@@ -378,3 +411,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
else:
speaker_manager.save_speaker_ids_to_file(out_file_path)
return speaker_manager
+
+
+def get_speaker_weighted_sampler(items: list):
+ speaker_names = np.array([item[2] for item in items])
+ unique_speaker_names = np.unique(speaker_names).tolist()
+ speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
+ speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
+ weight_speaker = 1.0 / speaker_count
+ dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
+ return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
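
Note: `get_random_speaker_id` and `get_random_d_vector` pick a random entry from the loaded mappings and are used when a test sentence names no speaker. A quick sketch, assuming the mapping is already populated (normally from `speakers.json`); the speaker names here are placeholders:

```python
from TTS.tts.utils.speakers import SpeakerManager

manager = SpeakerManager()
manager.speaker_ids = {"VCTK_p225": 0, "VCTK_p226": 1}   # hand-filled for illustration
speaker_id = manager.get_random_speaker_id()              # 0 or 1; None when the mapping is empty
```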
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 578c26c0..24b747be 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
import tensorflow as tf
-def text_to_seq(text, CONFIG, custom_symbols=None):
+def text_to_seq(text, CONFIG, custom_symbols=None, language=None):
text_cleaner = [CONFIG.text_cleaner]
# text ot phonemes to sequence vector
if CONFIG.use_phonemes:
@@ -23,7 +23,7 @@ def text_to_seq(text, CONFIG, custom_symbols=None):
phoneme_to_sequence(
text,
text_cleaner,
- CONFIG.phoneme_language,
+ language if language else CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
@@ -71,6 +71,7 @@ def run_model_torch(
speaker_id: int = None,
style_mel: torch.Tensor = None,
d_vector: torch.Tensor = None,
+ language_id: torch.Tensor = None,
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
@@ -96,6 +97,7 @@ def run_model_torch(
"speaker_ids": speaker_id,
"d_vectors": d_vector,
"style_mel": style_mel,
+ "language_ids": language_id,
},
)
return outputs
@@ -160,19 +162,20 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav
-def speaker_id_to_torch(speaker_id, cuda=False):
- if speaker_id is not None:
- speaker_id = np.asarray(speaker_id)
- speaker_id = torch.from_numpy(speaker_id)
+def id_to_torch(aux_id, cuda=False):
+ if aux_id is not None:
+ aux_id = np.asarray(aux_id)
+ aux_id = torch.from_numpy(aux_id)
if cuda:
- return speaker_id.cuda()
- return speaker_id
+ return aux_id.cuda()
+ return aux_id
def embedding_to_torch(d_vector, cuda=False):
if d_vector is not None:
d_vector = np.asarray(d_vector)
d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
+ d_vector = d_vector.squeeze().unsqueeze(0)
if cuda:
return d_vector.cuda()
return d_vector
@@ -208,6 +211,8 @@ def synthesis(
use_griffin_lim=False,
do_trim_silence=False,
d_vector=None,
+ language_id=None,
+ language_name=None,
backend="torch",
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
@@ -244,6 +249,12 @@ def synthesis(
d_vector (torch.Tensor):
d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None.
+ language_id (int):
+            Language ID passed to the language embedding layer of multi-lingual models. Defaults to None.
+
+ language_name (str):
+ Language name corresponding to the language code used by the phonemizer. Defaults to None.
+
backend (str):
tf or torch. Defaults to "torch".
"""
@@ -258,15 +269,18 @@ def synthesis(
if hasattr(model, "make_symbols"):
custom_symbols = model.make_symbols(CONFIG)
# preprocess the given text
- text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
+ text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name)
# pass tensors to backend
if backend == "torch":
if speaker_id is not None:
- speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
+ speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
if d_vector is not None:
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+ if language_id is not None:
+ language_id = id_to_torch(language_id, cuda=use_cuda)
+
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
@@ -278,7 +292,7 @@ def synthesis(
text_inputs = tf.expand_dims(text_inputs, 0)
# synthesize voice
if backend == "torch":
- outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)
+ outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]
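
A hedged sketch of how the extended `synthesis()` signature could be driven for multi-lingual inference; `model`, `CONFIG` and `ap` are assumed to be an already loaded TTS model, its config, and an `AudioProcessor`, and the speaker/language IDs are illustrative:

```python
from TTS.tts.utils.synthesis import synthesis

outputs = synthesis(
    model,
    "Bonjour tout le monde.",
    CONFIG,
    use_cuda=False,
    ap=ap,
    speaker_id=0,
    language_id=1,          # index into the model's language embedding table
    language_name="fr-fr",  # language code forwarded to the phonemizer
    use_griffin_lim=True,
)
# `outputs` is the dict returned by synthesis(); see Synthesizer.tts for how it is consumed.
```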
diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py
index 4b041ed8..f3ffa478 100644
--- a/TTS/tts/utils/text/cleaners.py
+++ b/TTS/tts/utils/text/cleaners.py
@@ -135,3 +135,12 @@ def phoneme_cleaners(text):
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text
+
+
+def multilingual_cleaners(text):
+ """Pipeline for multilingual text"""
+ text = lowercase(text)
+ text = replace_symbols(text, lang=None)
+ text = remove_aux_symbols(text)
+ text = collapse_whitespace(text)
+ return text
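
A quick, hedged example of the new cleaner; the input string is only illustrative:

```python
from TTS.tts.utils.text.cleaners import multilingual_cleaners

text = "  Früh übt sich,   was ein Meister   werden will!  "
# lowercased, symbols normalized, and whitespace collapsed
print(multilingual_cleaners(text))
```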
diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index e64b95e0..25f93c34 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -16,6 +16,60 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
TODO: Merge this with audio.py
+
+ Args:
+
+ n_fft (int):
+ FFT window size for STFT.
+
+ hop_length (int):
+ number of frames between STFT columns.
+
+ win_length (int, optional):
+ STFT window length.
+
+ pad_wav (bool, optional):
+            If True, pad the audio with (n_fft - hop_length) / 2 samples. Defaults to False.
+
+ window (str, optional):
+ The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
+
+ sample_rate (int, optional):
+ target audio sampling rate. Defaults to None.
+
+ mel_fmin (int, optional):
+ minimum filter frequency for computing melspectrograms. Defaults to None.
+
+ mel_fmax (int, optional):
+ maximum filter frequency for computing melspectrograms. Defaults to None.
+
+ n_mels (int, optional):
+ number of melspectrogram dimensions. Defaults to None.
+
+ use_mel (bool, optional):
+            If True, compute melspectrograms; otherwise return linear spectrograms. Defaults to False.
+
+ do_amp_to_db_linear (bool, optional):
+ enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
+
+ spec_gain (float, optional):
+ gain applied when converting amplitude to DB. Defaults to 1.0.
+
+ power (float, optional):
+ Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
+
+ use_htk (bool, optional):
+            If True, use the HTK formula for the mel filter bank instead of Slaney. Defaults to False.
+
+ mel_norm (None, 'slaney', or number, optional):
+ If 'slaney', divide the triangular mel weights by the width of the mel band
+ (area normalization).
+
+            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
+ See `librosa.util.normalize` for a full description of supported norm values
+ (including `+-np.inf`).
+
+ Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
"""
def __init__(
@@ -32,6 +86,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
+ power=None,
+ use_htk=False,
+ mel_norm="slaney",
):
super().__init__()
self.n_fft = n_fft
@@ -45,6 +102,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
+ self.power = power
+ self.use_htk = use_htk
+ self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
@@ -83,6 +143,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+
+ if self.power is not None:
+ S = S ** self.power
+
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
@@ -91,7 +155,13 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
- self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+ self.sample_rate,
+ self.n_fft,
+ n_mels=self.n_mels,
+ fmin=self.mel_fmin,
+ fmax=self.mel_fmax,
+ htk=self.use_htk,
+ norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()
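
To make the new flags concrete, this is roughly how they map onto librosa's filter bank builder, following the positional `sample_rate`/`n_fft` call style used above; the parameter values are illustrative:

```python
import librosa

mel_basis = librosa.filters.mel(
    22050,            # sample_rate
    1024,             # n_fft
    n_mels=80,
    fmin=0,
    fmax=8000,
    htk=False,        # use_htk=False -> Slaney mel scale
    norm="slaney",    # mel_norm="slaney" -> area-normalized triangles
)
print(mel_basis.shape)  # (80, 513) == (n_mels, 1 + n_fft // 2)
```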
@@ -167,7 +237,7 @@ class AudioProcessor(object):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
- maximum filter frequency for computing melspectrograms.. Defaults to None.
+ maximum filter frequency for computing melspectrograms. Defaults to None.
spec_gain (int, optional):
gain applied when converting amplitude to DB. Defaults to 20.
@@ -196,6 +266,12 @@ class AudioProcessor(object):
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
+ do_rms_norm (bool, optional):
+ enable/disable RMS volume normalization when loading an audio file. Defaults to False.
+
+ db_level (int, optional):
+ dB level used for rms normalization. The range is -99 to 0. Defaults to None.
+
stats_path (str, optional):
Path to the computed stats file. Defaults to None.
@@ -233,6 +309,8 @@ class AudioProcessor(object):
do_sound_norm=False,
do_amp_to_db_linear=True,
do_amp_to_db_mel=True,
+ do_rms_norm=False,
+ db_level=None,
stats_path=None,
verbose=True,
**_,
@@ -264,6 +342,8 @@ class AudioProcessor(object):
self.do_sound_norm = do_sound_norm
self.do_amp_to_db_linear = do_amp_to_db_linear
self.do_amp_to_db_mel = do_amp_to_db_mel
+ self.do_rms_norm = do_rms_norm
+ self.db_level = db_level
self.stats_path = stats_path
# setup exp_func for db to amp conversion
if log_func == "np.log":
@@ -656,21 +736,6 @@ class AudioProcessor(object):
frame_period=1000 * self.hop_length / self.sample_rate,
)
f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
- # pad = int((self.win_length / self.hop_length) / 2)
- # f0 = [0.0] * pad + f0 + [0.0] * pad
- # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0)
- # f0 = np.array(f0, dtype=np.float32)
-
- # f01, _, _ = librosa.pyin(
- # x,
- # fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
- # fmax=self.mel_fmax,
- # frame_length=self.win_length,
- # sr=self.sample_rate,
- # fill_na=0.0,
- # )
-
- # spec = self.melspectrogram(x)
return f0
### Audio Processing ###
@@ -713,10 +778,33 @@ class AudioProcessor(object):
"""
return x / abs(x).max() * 0.95
+ @staticmethod
+ def _rms_norm(wav, db_level=-27):
+ r = 10 ** (db_level / 20)
+ a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
+ return wav * a
+
+ def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
+ """Normalize the volume based on RMS of the signal.
+
+ Args:
+            x (np.ndarray): Raw waveform.
+
+            db_level (float, optional): Target dB level in the range [-99, 0]. Defaults to None, which falls back to `self.db_level`.
+
+        Returns:
+ np.ndarray: RMS normalized waveform.
+ """
+ if db_level is None:
+ db_level = self.db_level
+ assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
+ wav = self._rms_norm(x, db_level)
+ return wav
+
### save and load ###
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
+        Resampling slows down loading the file significantly. Therefore it is recommended to resample the files beforehand.
+
Args:
filename (str): Path to the wav file.
sr (int, optional): Sampling rate for resampling. Defaults to None.
@@ -725,8 +813,10 @@ class AudioProcessor(object):
np.ndarray: Loaded waveform.
"""
if self.resample:
+ # loading with resampling. It is significantly slower.
x, sr = librosa.load(filename, sr=self.sample_rate)
elif sr is None:
+            # soundfile (sf) is faster than librosa for loading files
x, sr = sf.read(filename)
assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
else:
@@ -738,6 +828,8 @@ class AudioProcessor(object):
print(f" [!] File cannot be trimmed for silence - {filename}")
if self.do_sound_norm:
x = self.sound_norm(x)
+ if self.do_rms_norm:
+ x = self.rms_volume_norm(x, self.db_level)
return x
def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:
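
As a sanity check on the RMS normalization above: the gain computed in `_rms_norm` scales the waveform so that its root-mean-square equals the target level `r = 10 ** (db_level / 20)`. A small, self-contained numeric sketch (the waveform here is random test data, not part of the code base):

```python
import numpy as np

wav = np.random.uniform(-0.1, 0.1, 22050)
db_level = -27.0
r = 10 ** (db_level / 20)                              # target RMS in linear scale (~0.0447)
a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))  # same gain as AudioProcessor._rms_norm
out = wav * a
print(np.sqrt(np.mean(out ** 2)), r)                   # both print ~0.0447
```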
diff --git a/TTS/utils/download.py b/TTS/utils/download.py
index 5cfb69cd..241a106b 100644
--- a/TTS/utils/download.py
+++ b/TTS/utils/download.py
@@ -7,6 +7,7 @@ import tarfile
import urllib
import urllib.request
import zipfile
+from os.path import expanduser
from typing import Any, Iterable, List, Optional
from torch.utils.model_zoo import tqdm
@@ -183,3 +184,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
pass
raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")
+
+
+def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str):
+ """Download dataset from kaggle.
+ Args:
+ dataset_path (str):
+            This is the Kaggle link to the dataset. For example, the VCTK link is 'mfekadu/english-multispeaker-corpus-for-voice-cloning'.
+ dataset_name (str): Name of the folder the dataset will be saved in.
+ output_path (str): Path of the location you want the dataset folder to be saved to.
+ """
+ data_path = os.path.join(output_path, dataset_name)
+ try:
+ import kaggle # pylint: disable=import-outside-toplevel
+
+ kaggle.api.authenticate()
+ print(f"""\nDownloading {dataset_name}...""")
+ kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
+ except OSError:
+ print(
+            f"""[!] In order to download Kaggle datasets, you need to have a Kaggle API token stored in {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
+ )
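
A minimal usage sketch for the Kaggle helper, assuming the `kaggle` package is installed and an API token sits at `~/.kaggle/kaggle.json`; the output path is illustrative:

```python
from TTS.utils.download import download_kaggle_dataset

download_kaggle_dataset(
    "mfekadu/english-multispeaker-corpus-for-voice-cloning",  # Kaggle dataset path
    "VCTK",                                                   # local folder name
    "/tmp/datasets",                                          # output root
)
```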
diff --git a/TTS/utils/downloaders.py b/TTS/utils/downloaders.py
index 89f2148f..104dc7b9 100644
--- a/TTS/utils/downloaders.py
+++ b/TTS/utils/downloaders.py
@@ -1,6 +1,7 @@
import os
+from typing import Optional
-from TTS.utils.download import download_url, extract_archive
+from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive
def download_ljspeech(path: str):
@@ -18,14 +19,106 @@ def download_ljspeech(path: str):
extract_archive(archive)
-def download_vctk(path: str):
- """Download and extract VCTK dataset
+def download_vctk(path: str, use_kaggle: Optional[bool] = False):
+ """Download and extract VCTK dataset.
Args:
path (str): path to the directory where the dataset will be stored.
+
+        use_kaggle (bool, optional): Download the VCTK dataset from Kaggle, which is generally faster. Defaults to False.
+ """
+ if use_kaggle:
+ download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path)
+ else:
+ os.makedirs(path, exist_ok=True)
+ url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
+ download_url(url, path)
+ basename = os.path.basename(url)
+ archive = os.path.join(path, basename)
+ print(" > Extracting archive file...")
+ extract_archive(archive)
+
+
+def download_tweb(path: str):
+    """Download and extract the TWEB dataset.
+
+ Args:
+ path (str): Path to the directory where the dataset will be stored.
+ """
+ download_kaggle_dataset("bryanpark/the-world-english-bible-speech-dataset", "TWEB", path)
+
+
+def download_libri_tts(path: str, subset: Optional[str] = "all"):
+    """Download and extract the LibriTTS dataset.
+
+ Args:
+ path (str): Path to the directory where the dataset will be stored.
+
+        subset (str, optional): Name of the subset to download. If you only want to download a certain
+        portion, specify it here. Defaults to 'all'.
+ """
+
+ subset_dict = {
+ "libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz",
+ "libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz",
+ "libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz",
+ "libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz",
+ "libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz",
+ "libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz",
+ "libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz",
+ }
+
+ os.makedirs(path, exist_ok=True)
+ if subset == "all":
+ for sub, val in subset_dict.items():
+ print(f" > Downloading {sub}...")
+ download_url(val, path)
+ basename = os.path.basename(val)
+ archive = os.path.join(path, basename)
+ print(" > Extracting archive file...")
+ extract_archive(archive)
+ print(" > All subsets downloaded")
+ else:
+ url = subset_dict[subset]
+ download_url(url, path)
+ basename = os.path.basename(url)
+ archive = os.path.join(path, basename)
+ print(" > Extracting archive file...")
+ extract_archive(archive)
+
+
+def download_thorsten_de(path: str):
+    """Download and extract the Thorsten German male voice dataset.
+
+ Args:
+ path (str): Path to the directory where the dataset will be stored.
"""
os.makedirs(path, exist_ok=True)
- url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
+ url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz"
+ download_url(url, path)
+ basename = os.path.basename(url)
+ archive = os.path.join(path, basename)
+ print(" > Extracting archive file...")
+ extract_archive(archive)
+
+
+def download_mailabs(path: str, language: str = "english"):
+ """Download and extract Mailabs dataset.
+
+ Args:
+ path (str): Path to the directory where the dataset will be stored.
+
+ language (str): Language subset to download. Defaults to english.
+ """
+ language_dict = {
+ "english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz",
+ "german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz",
+ "french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz",
+ "italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz",
+ "spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz",
+ }
+ os.makedirs(path, exist_ok=True)
+ url = language_dict[language]
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
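
For reference, a hedged sketch of how the new downloaders are meant to be called; the local paths are illustrative:

```python
from TTS.utils.downloaders import download_libri_tts, download_mailabs, download_vctk

download_vctk("/data/VCTK", use_kaggle=True)                         # faster Kaggle mirror
download_libri_tts("/data/LibriTTS", subset="libri-tts-clean-100")   # a single LibriTTS subset
download_mailabs("/data/mailabs", language="german")                 # de_DE portion of M-AILABS
```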
diff --git a/TTS/utils/io.py b/TTS/utils/io.py
index a93f6118..54818ce9 100644
--- a/TTS/utils/io.py
+++ b/TTS/utils/io.py
@@ -26,7 +26,7 @@ class AttrDict(dict):
self.__dict__ = self
-def copy_model_files(config: Coqpit, out_path, new_fields):
+def copy_model_files(config: Coqpit, out_path, new_fields=None):
"""Copy config.json and other model files to training folder and add
new fields.
diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index cfbbdff0..01d54ad6 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -46,36 +46,66 @@ class ModelManager(object):
with open(file_path, "r", encoding="utf-8") as json_file:
self.models_dict = json.load(json_file)
- def list_langs(self):
- print(" Name format: type/language")
- for model_type in self.models_dict:
- for lang in self.models_dict[model_type]:
- print(f" >: {model_type}/{lang} ")
+ def _list_models(self, model_type, model_count=0):
+ model_list = []
+ for lang in self.models_dict[model_type]:
+ for dataset in self.models_dict[model_type][lang]:
+ for model in self.models_dict[model_type][lang][dataset]:
+ model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
+ output_path = os.path.join(self.output_prefix, model_full_name)
+ if os.path.exists(output_path):
+ print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
+ else:
+ print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
+ model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
+ model_count += 1
+ return model_list
- def list_datasets(self):
- print(" Name format: type/language/dataset")
- for model_type in self.models_dict:
- for lang in self.models_dict[model_type]:
- for dataset in self.models_dict[model_type][lang]:
- print(f" >: {model_type}/{lang}/{dataset}")
+ def _list_for_model_type(self, model_type):
+ print(" Name format: language/dataset/model")
+ models_name_list = []
+        model_count = 1
+        models_name_list.extend(self._list_models(model_type, model_count))
+ return [name.replace(model_type + "/", "") for name in models_name_list]
def list_models(self):
print(" Name format: type/language/dataset/model")
models_name_list = []
model_count = 1
+ for model_type in self.models_dict:
+ model_list = self._list_models(model_type, model_count)
+ models_name_list.extend(model_list)
+ return models_name_list
+
+ def list_tts_models(self):
+ """Print all `TTS` models and return a list of model names
+
+ Format is `language/dataset/model`
+ """
+ return self._list_for_model_type("tts_models")
+
+ def list_vocoder_models(self):
+ """Print all the `vocoder` models and return a list of model names
+
+ Format is `language/dataset/model`
+ """
+ return self._list_for_model_type("vocoder_models")
+
+ def list_langs(self):
+ """Print all the available languages"""
+ print(" Name format: type/language")
+ for model_type in self.models_dict:
+ for lang in self.models_dict[model_type]:
+ print(f" >: {model_type}/{lang} ")
+
+ def list_datasets(self):
+ """Print all the datasets"""
+ print(" Name format: type/language/dataset")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
- for model in self.models_dict[model_type][lang][dataset]:
- model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
- output_path = os.path.join(self.output_prefix, model_full_name)
- if os.path.exists(output_path):
- print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
- else:
- print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
- models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
- model_count += 1
- return models_name_list
+ print(f" >: {model_type}/{lang}/{dataset}")
def download_model(self, model_name):
"""Download model files given the full model name.
@@ -121,6 +151,8 @@ class ModelManager(object):
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
+ speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
+ speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
# update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path)
@@ -133,6 +165,12 @@ class ModelManager(object):
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
+ # update the speaker_encoder file path in the model config.json to the current path
+ self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
+ self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
+ self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
+ self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
+
@staticmethod
def _update_path(field_name, new_path, config_path):
"""Update the path in the model config.json for the current environment after download"""
@@ -159,8 +197,12 @@ class ModelManager(object):
# download the file
r = requests.get(file_url)
# extract the file
- with zipfile.ZipFile(io.BytesIO(r.content)) as z:
- z.extractall(output_folder)
+ try:
+ with zipfile.ZipFile(io.BytesIO(r.content)) as z:
+ z.extractall(output_folder)
+ except zipfile.BadZipFile:
+ print(f" > Error: Bad zip file - {file_url}")
+ raise zipfile.BadZipFile # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in z.namelist()[1:]:
src_path = os.path.join(output_folder, file_path)
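
A short usage sketch of the reworked listing API, assuming a `ModelManager` constructed with the path to the `.models.json` catalogue shipped with TTS (location assumed):

```python
from TTS.utils.manage import ModelManager

manager = ModelManager("TTS/.models.json")       # path to the model catalogue (assumed location)
tts_names = manager.list_tts_models()            # "language/dataset/model" strings
vocoder_names = manager.list_vocoder_models()
manager.list_langs()                             # prints "type/language" entries
```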
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 043c4982..fc45e7fa 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -1,12 +1,13 @@
import time
-from typing import List
+from typing import List, Union
import numpy as np
import pysbd
import torch
-from TTS.config import load_config
+from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config
from TTS.tts.models import setup_model as setup_tts_model
+from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
# pylint: disable=unused-wildcard-import
@@ -23,6 +24,7 @@ class Synthesizer(object):
tts_checkpoint: str,
tts_config_path: str,
tts_speakers_file: str = "",
+ tts_languages_file: str = "",
vocoder_checkpoint: str = "",
vocoder_config: str = "",
encoder_checkpoint: str = "",
@@ -52,6 +54,7 @@ class Synthesizer(object):
self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file
+ self.tts_languages_file = tts_languages_file
self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config
self.encoder_checkpoint = encoder_checkpoint
@@ -63,6 +66,9 @@ class Synthesizer(object):
self.speaker_manager = None
self.num_speakers = 0
self.tts_speakers = {}
+ self.language_manager = None
+ self.num_languages = 0
+ self.tts_languages = {}
self.d_vector_dim = 0
self.seg = self._get_segmenter("en")
self.use_cuda = use_cuda
@@ -110,29 +116,94 @@ class Synthesizer(object):
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
speaker_manager = self._init_speaker_manager()
+ language_manager = self._init_language_manager()
+ if not self.encoder_checkpoint:
+ self._set_speaker_encoder_paths_from_tts_config()
+ speaker_manager = self._init_speaker_encoder(speaker_manager)
- self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
+ if language_manager is not None:
+ self.tts_model = setup_tts_model(
+ config=self.tts_config,
+ speaker_manager=speaker_manager,
+ language_manager=language_manager,
+ )
+ else:
+ self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda:
self.tts_model.cuda()
+ def _set_speaker_encoder_paths_from_tts_config(self):
+ """Set the encoder paths from the tts model config for models with speaker encoders."""
+ if hasattr(self.tts_config, "model_args") and hasattr(
+ self.tts_config.model_args, "speaker_encoder_config_path"
+ ):
+ self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
+ self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
+
+ def _is_use_speaker_embedding(self):
+ """Check if the speaker embedding is used in the model"""
+ # we handle here the case that some models use model_args some don't
+ use_speaker_embedding = False
+ if hasattr(self.tts_config, "model_args"):
+ use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False)
+ use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False)
+ return use_speaker_embedding
+
+ def _is_use_d_vector_file(self):
+ """Check if the d-vector file is used in the model"""
+ # we handle here the case that some models use model_args some don't
+ use_d_vector_file = False
+ if hasattr(self.tts_config, "model_args"):
+ config = self.tts_config.model_args
+ use_d_vector_file = config.get("use_d_vector_file", False)
+ config = self.tts_config
+ use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False)
+ return use_d_vector_file
+
def _init_speaker_manager(self):
"""Initialize the SpeakerManager"""
# setup if multi-speaker settings are in the global model config
speaker_manager = None
- if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
+ speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None)
+ if self._is_use_speaker_embedding():
if self.tts_speakers_file:
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
- if self.tts_config.get("speakers_file", None):
- speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
+ elif speakers_file:
+ speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file)
- if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True:
+ if self._is_use_d_vector_file():
+ d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None)
if self.tts_speakers_file:
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
- if self.tts_config.get("d_vector_file", None):
- speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
+ elif d_vector_file:
+ speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file)
return speaker_manager
+ def _init_speaker_encoder(self, speaker_manager):
+ """Initialize the SpeakerEncoder"""
+ if self.encoder_checkpoint:
+ if speaker_manager is None:
+ speaker_manager = SpeakerManager(
+ encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config
+ )
+ else:
+ speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
+ return speaker_manager
+
+ def _init_language_manager(self):
+ """Initialize the LanguageManager"""
+ # setup if multi-lingual settings are in the global model config
+ language_manager = None
+ if check_config_and_model_args(self.tts_config, "use_language_embedding", True):
+ if self.tts_languages_file:
+ language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
+ elif self.tts_config.get("language_ids_file", None):
+ language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file)
+ else:
+ language_manager = LanguageManager(config=self.tts_config)
+ return language_manager
+
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
"""Load the vocoder model.
@@ -174,13 +245,21 @@ class Synthesizer(object):
wav = np.array(wav)
self.ap.save_wav(wav, path, self.output_sample_rate)
- def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]:
+ def tts(
+ self,
+ text: str,
+ speaker_name: str = "",
+ language_name: str = "",
+ speaker_wav: Union[str, List[str]] = None,
+ style_wav=None,
+ ) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
Args:
text (str): input text.
- speaker_idx (str, optional): spekaer id for multi-speaker models. Defaults to "".
- speaker_wav ():
+            speaker_name (str, optional): speaker name for multi-speaker models. Defaults to "".
+            language_name (str, optional): language name for multi-lingual models. Defaults to "".
+            speaker_wav (Union[str, List[str]], optional): path (or list of paths) to the reference speaker wav(s). Defaults to None.
style_wav ([type], optional): style waveform for GST. Defaults to None.
Returns:
@@ -196,29 +275,49 @@ class Synthesizer(object):
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
- if speaker_idx and isinstance(speaker_idx, str):
+ if speaker_name and isinstance(speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the speaker embedding from the saved d_vectors.
- speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
+ speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_name)[0]
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
- speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
+ speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name]
- elif not speaker_idx and not speaker_wav:
+ elif not speaker_name and not speaker_wav:
raise ValueError(
" [!] Look like you use a multi-speaker model. "
- "You need to define either a `speaker_idx` or a `style_wav` to use a multi-speaker model."
+ "You need to define either a `speaker_name` or a `style_wav` to use a multi-speaker model."
)
else:
speaker_embedding = None
else:
- if speaker_idx:
+ if speaker_name:
raise ValueError(
- f" [!] Missing speakers.json file path for selecting speaker {speaker_idx}."
+ f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
)
+        # handle multi-lingual
+ language_id = None
+ if self.tts_languages_file or (
+ hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
+ ):
+ if language_name and isinstance(language_name, str):
+ language_id = self.tts_model.language_manager.language_id_mapping[language_name]
+
+ elif not language_name:
+ raise ValueError(
+                    " [!] Looks like you are using a multi-lingual model. "
+                    "You need to define a `language_name` to use a multi-lingual model."
+ )
+
+ else:
+ raise ValueError(
+ f" [!] Missing language_ids.json file path for selecting language {language_name}."
+ "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
+ )
+
# compute a new d_vector from the given clip.
if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
@@ -234,6 +333,8 @@ class Synthesizer(object):
use_cuda=self.use_cuda,
ap=self.ap,
speaker_id=speaker_id,
+ language_id=language_id,
+ language_name=language_name,
style_wav=style_wav,
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
use_griffin_lim=use_gl,
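
Putting the new arguments together, a hedged end-to-end sketch for a multi-lingual, multi-speaker checkpoint; all file paths, the speaker name and the language code are placeholders:

```python
from TTS.utils.synthesizer import Synthesizer

synth = Synthesizer(
    tts_checkpoint="best_model.pth.tar",
    tts_config_path="config.json",
    tts_speakers_file="speakers.json",
    tts_languages_file="language_ids.json",
    use_cuda=False,
)
wav = synth.tts(
    "Bonjour tout le monde.",
    speaker_name="ezwa",
    language_name="fr-fr",
)
synth.save_wav(wav, "output.wav")
```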
diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py
new file mode 100644
index 00000000..923544d0
--- /dev/null
+++ b/TTS/utils/vad.py
@@ -0,0 +1,144 @@
+# This code is adapted from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
+import collections
+import contextlib
+import wave
+
+import webrtcvad
+
+
+def read_wave(path):
+ """Reads a .wav file.
+
+ Takes the path, and returns (PCM audio data, sample rate).
+ """
+ with contextlib.closing(wave.open(path, "rb")) as wf:
+ num_channels = wf.getnchannels()
+ assert num_channels == 1
+ sample_width = wf.getsampwidth()
+ assert sample_width == 2
+ sample_rate = wf.getframerate()
+ assert sample_rate in (8000, 16000, 32000, 48000)
+ pcm_data = wf.readframes(wf.getnframes())
+ return pcm_data, sample_rate
+
+
+def write_wave(path, audio, sample_rate):
+ """Writes a .wav file.
+
+ Takes path, PCM audio data, and sample rate.
+ """
+ with contextlib.closing(wave.open(path, "wb")) as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2)
+ wf.setframerate(sample_rate)
+ wf.writeframes(audio)
+
+
+class Frame(object):
+ """Represents a "frame" of audio data."""
+
+ def __init__(self, _bytes, timestamp, duration):
+ self.bytes = _bytes
+ self.timestamp = timestamp
+ self.duration = duration
+
+
+def frame_generator(frame_duration_ms, audio, sample_rate):
+ """Generates audio frames from PCM audio data.
+
+ Takes the desired frame duration in milliseconds, the PCM data, and
+ the sample rate.
+
+ Yields Frames of the requested duration.
+ """
+ n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
+ offset = 0
+ timestamp = 0.0
+ duration = (float(n) / sample_rate) / 2.0
+ while offset + n < len(audio):
+ yield Frame(audio[offset : offset + n], timestamp, duration)
+ timestamp += duration
+ offset += n
+
+
+def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
+ """Filters out non-voiced audio frames.
+
+ Given a webrtcvad.Vad and a source of audio frames, yields only
+ the voiced audio.
+
+ Uses a padded, sliding window algorithm over the audio frames.
+ When more than 90% of the frames in the window are voiced (as
+ reported by the VAD), the collector triggers and begins yielding
+ audio frames. Then the collector waits until 90% of the frames in
+ the window are unvoiced to detrigger.
+
+ The window is padded at the front and back to provide a small
+ amount of silence or the beginnings/endings of speech around the
+ voiced frames.
+
+ Arguments:
+
+ sample_rate - The audio sample rate, in Hz.
+ frame_duration_ms - The frame duration in milliseconds.
+ padding_duration_ms - The amount to pad the window, in milliseconds.
+ vad - An instance of webrtcvad.Vad.
+ frames - a source of audio frames (sequence or generator).
+
+ Returns: A generator that yields PCM audio data.
+ """
+ num_padding_frames = int(padding_duration_ms / frame_duration_ms)
+ # We use a deque for our sliding window/ring buffer.
+ ring_buffer = collections.deque(maxlen=num_padding_frames)
+ # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
+ # NOTTRIGGERED state.
+ triggered = False
+
+ voiced_frames = []
+ for frame in frames:
+ is_speech = vad.is_speech(frame.bytes, sample_rate)
+
+ # sys.stdout.write('1' if is_speech else '0')
+ if not triggered:
+ ring_buffer.append((frame, is_speech))
+ num_voiced = len([f for f, speech in ring_buffer if speech])
+ # If we're NOTTRIGGERED and more than 90% of the frames in
+ # the ring buffer are voiced frames, then enter the
+ # TRIGGERED state.
+ if num_voiced > 0.9 * ring_buffer.maxlen:
+ triggered = True
+ # sys.stdout.write('+(%s)' % (ring_buffer[0][0].timestamp,))
+ # We want to yield all the audio we see from now until
+ # we are NOTTRIGGERED, but we have to start with the
+ # audio that's already in the ring buffer.
+ for f, _ in ring_buffer:
+ voiced_frames.append(f)
+ ring_buffer.clear()
+ else:
+ # We're in the TRIGGERED state, so collect the audio data
+ # and add it to the ring buffer.
+ voiced_frames.append(frame)
+ ring_buffer.append((frame, is_speech))
+ num_unvoiced = len([f for f, speech in ring_buffer if not speech])
+ # If more than 90% of the frames in the ring buffer are
+ # unvoiced, then enter NOTTRIGGERED and yield whatever
+ # audio we've collected.
+ if num_unvoiced > 0.9 * ring_buffer.maxlen:
+ # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
+ triggered = False
+ yield b"".join([f.bytes for f in voiced_frames])
+ ring_buffer.clear()
+ voiced_frames = []
+ # If we have any leftover voiced audio when we run out of input,
+ # yield it.
+ if voiced_frames:
+ yield b"".join([f.bytes for f in voiced_frames])
+
+
+def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):
+
+ vad = webrtcvad.Vad(int(aggressiveness))
+ frames = list(frame_generator(30, audio, sample_rate))
+ segments = vad_collector(sample_rate, 30, padding_duration_ms, vad, frames)
+
+ return segments
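
A minimal sketch tying the VAD helpers together, assuming a mono 16 kHz, 16-bit PCM wav file named `input.wav`:

```python
from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave

audio, sample_rate = read_wave("input.wav")
segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=2)
for i, segment in enumerate(segments):
    write_wave(f"chunk-{i:02d}.wav", segment, sample_rate)
```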
diff --git a/TTS/vocoder/configs/shared_configs.py b/TTS/vocoder/configs/shared_configs.py
index c5d6a8b4..9ff6f790 100644
--- a/TTS/vocoder/configs/shared_configs.py
+++ b/TTS/vocoder/configs/shared_configs.py
@@ -113,8 +113,10 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
lr_scheduler_disc (torch.optim.Scheduler):
Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`.
- lr_scheduler_dict_params (dict):
+ lr_scheduler_disc_params (dict):
Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
+ scheduler_after_epoch (bool):
+ Whether to update the learning rate schedulers after each epoch. Defaults to True.
use_pqmf (bool):
enable / disable PQMF for subband approximation at training. Defaults to False.
steps_to_start_discriminator (int):
@@ -173,6 +175,7 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
+ scheduler_after_epoch: bool = True
use_pqmf: bool = False # enable/disable using pqmf for multi-band training. (Multi-band MelGAN)
steps_to_start_discriminator = 0 # start training the discriminator after this number of steps.
diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py
index e36c2cd1..76fee505 100644
--- a/TTS/vocoder/models/gan.py
+++ b/TTS/vocoder/models/gan.py
@@ -202,7 +202,9 @@ class GAN(BaseVocoder):
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for training."""
ap = assets["audio_processor"]
- self._log("train", ap, batch, outputs)
+        figures, audios = self._log("train", ap, batch, outputs)
+        logger.train_figures(steps, figures)
+        logger.train_audios(steps, audios, ap.sample_rate)
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
@@ -214,7 +216,9 @@ class GAN(BaseVocoder):
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for evaluation."""
ap = assets["audio_processor"]
- self._log("eval", ap, batch, outputs)
+ figures, audios = self._log("eval", ap, batch, outputs)
+ logger.eval_figures(steps, figures)
+ logger.eval_audios(steps, audios, ap.sample_rate)
def load_checkpoint(
self,
diff --git a/docs/source/models/vits.md b/docs/source/models/vits.md
index 5c0e92f6..0c303f7a 100644
--- a/docs/source/models/vits.md
+++ b/docs/source/models/vits.md
@@ -3,10 +3,15 @@
VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech
) is an End-to-End (encoder -> vocoder together) TTS model that takes advantage of SOTA DL techniques like GANs, VAE,
Normalizing Flows. It does not require external alignment annotations and learns the text-to-audio alignment
-using MAS as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder.
+using MAS, as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder.
It is a feed-forward model with x67.12 real-time factor on a GPU.
+🐸 YourTTS is a multi-speaker and multi-lingual TTS model that can perform voice conversion and zero-shot speaker adaptation.
+It can also learn a new language or voice with as little as ~1 minute of audio. This opens the door to training
+TTS models for low-resource languages. 🐸 YourTTS uses VITS as the backbone architecture, coupled with a speaker encoder model.
+
## Important resources & papers
+- 🐸 YourTTS: https://arxiv.org/abs/2112.02418
- VITS: https://arxiv.org/pdf/2106.06103.pdf
- Neural Spline Flows: https://arxiv.org/abs/1906.04032
- Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf
diff --git a/notebooks/dataset_analysis/analyze.py b/notebooks/dataset_analysis/analyze.py
index 9ba42fb9..4855886e 100644
--- a/notebooks/dataset_analysis/analyze.py
+++ b/notebooks/dataset_analysis/analyze.py
@@ -180,7 +180,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
plt.figure()
plt.rcParams["figure.figsize"] = (50, 20)
- barplot = sns.barplot(x, y)
+ barplot = sns.barplot(x=x, y=y)
if save_path:
fig = barplot.get_figure()
fig.savefig(os.path.join(save_path, "phoneme_dist"))
diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py
new file mode 100644
index 00000000..be4747df
--- /dev/null
+++ b/recipes/multilingual/vits_tts/train_vits_tts.py
@@ -0,0 +1,130 @@
+import os
+from glob import glob
+
+from TTS.config.shared_configs import BaseAudioConfig
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.tts.configs.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
+from TTS.utils.audio import AudioProcessor
+
+output_path = os.path.dirname(os.path.abspath(__file__))
+
+mailabs_path = "/home/julian/workspace/mailabs/**"
+dataset_paths = glob(mailabs_path)
+dataset_config = [
+ BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
+ for path in dataset_paths
+]
+
+audio_config = BaseAudioConfig(
+ sample_rate=16000,
+ win_length=1024,
+ hop_length=256,
+ num_mels=80,
+ preemphasis=0.0,
+ ref_level_db=20,
+ log_func="np.log",
+ do_trim_silence=False,
+ trim_db=23.0,
+ mel_fmin=0,
+ mel_fmax=None,
+ spec_gain=1.0,
+ signal_norm=True,
+ do_amp_to_db_linear=False,
+ resample=False,
+)
+
+vitsArgs = VitsArgs(
+ use_language_embedding=True,
+ embedded_language_dim=4,
+ use_speaker_embedding=True,
+ use_sdp=False,
+)
+
+config = VitsConfig(
+ model_args=vitsArgs,
+ audio=audio_config,
+ run_name="vits_vctk",
+ use_speaker_embedding=True,
+ batch_size=32,
+ eval_batch_size=16,
+ batch_group_size=0,
+ num_loader_workers=4,
+ num_eval_loader_workers=4,
+ run_eval=True,
+ test_delay_epochs=-1,
+ epochs=1000,
+ text_cleaner="multilingual_cleaners",
+ use_phonemes=False,
+ phoneme_language="en-us",
+ phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
+ compute_input_seq_cache=True,
+ print_step=25,
+ use_language_weighted_sampler=True,
+ print_eval=False,
+ mixed_precision=False,
+ sort_by_audio_len=True,
+ min_seq_len=32 * 256 * 4,
+ max_seq_len=160000,
+ output_path=output_path,
+ datasets=dataset_config,
+ characters={
+ "pad": "_",
+ "eos": "&",
+ "bos": "*",
+ "characters": "!¡'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
+ "punctuations": "!¡'(),-.:;¿? ",
+ "phonemes": None,
+ "unique": True,
+ },
+ test_sentences=[
+ [
+ "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+ "mary_ann",
+ None,
+ "en_US",
+ ],
+ [
+ "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
+ "ezwa",
+ None,
+ "fr_FR",
+ ],
+ ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
+ ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
+ ],
+)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+
+# init speaker manager for multi-speaker training
+# it maps speaker-id to speaker-name in the model and data-loader
+speaker_manager = SpeakerManager()
+speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+config.model_args.num_speakers = speaker_manager.num_speakers
+
+language_manager = LanguageManager(config=config)
+config.model_args.num_languages = language_manager.num_languages
+
+# init model
+model = Vits(config, speaker_manager, language_manager)
+
+# init the trainer and 🚀
+trainer = Trainer(
+ TrainingArgs(),
+ config,
+ output_path,
+ model=model,
+ train_samples=train_samples,
+ eval_samples=eval_samples,
+ training_assets={"audio_processor": ap},
+)
+trainer.fit()
diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py
index 19074ce3..7eb741c4 100644
--- a/recipes/vctk/vits/train_vits.py
+++ b/recipes/vctk/vits/train_vits.py
@@ -5,12 +5,14 @@ from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits
+from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_config = BaseDatasetConfig(
+ name="vctk", meta_file_train="", language="en-us", path=os.path.join(output_path, "../VCTK/")
+)
audio_config = BaseAudioConfig(
@@ -31,10 +33,14 @@ audio_config = BaseAudioConfig(
resample=True,
)
+vitsArgs = VitsArgs(
+ use_speaker_embedding=True,
+)
+
config = VitsConfig(
+ model_args=vitsArgs,
audio=audio_config,
run_name="vits_vctk",
- use_speaker_embedding=True,
batch_size=32,
eval_batch_size=16,
batch_group_size=5,
@@ -45,7 +51,6 @@ config = VitsConfig(
epochs=1000,
text_cleaner="english_cleaners",
use_phonemes=True,
- phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
diff --git a/requirements.txt b/requirements.txt
index 3ec33ceb..ddb6def9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,3 +26,5 @@ unidic-lite==1.0.8
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
fsspec>=2021.04.0
pyworld
+webrtcvad
+torchaudio
diff --git a/tests/__init__.py b/tests/__init__.py
index 45aee23a..0a0c3379 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -38,3 +38,14 @@ def run_cli(command):
def get_test_data_config():
return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")
+
+
+def assertHasAttr(test_obj, obj, intendedAttr):
+ # from https://stackoverflow.com/questions/48078636/pythons-unittest-lacks-an-asserthasattr-method-what-should-i-use-instead
+ testBool = hasattr(obj, intendedAttr)
+ test_obj.assertTrue(testBool, msg=f"obj lacking an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
+
+
+def assertHasNotAttr(test_obj, obj, intendedAttr):
+ testBool = hasattr(obj, intendedAttr)
+ test_obj.assertFalse(testBool, msg=f"obj should not have an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
diff --git a/tests/aux_tests/test_find_unique_phonemes.py b/tests/aux_tests/test_find_unique_phonemes.py
new file mode 100644
index 00000000..fa0abe4b
--- /dev/null
+++ b/tests/aux_tests/test_find_unique_phonemes.py
@@ -0,0 +1,80 @@
+import os
+import unittest
+
+import torch
+
+from tests import get_tests_output_path, run_cli
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig
+
+torch.manual_seed(1)
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+
+dataset_config_en = BaseDatasetConfig(
+ name="ljspeech",
+ meta_file_train="metadata.csv",
+ meta_file_val="metadata.csv",
+ path="tests/data/ljspeech",
+ language="en",
+)
+
+dataset_config_pt = BaseDatasetConfig(
+ name="ljspeech",
+ meta_file_train="metadata.csv",
+ meta_file_val="metadata.csv",
+ path="tests/data/ljspeech",
+ language="pt-br",
+)
+
+# pylint: disable=protected-access
+class TestFindUniquePhonemes(unittest.TestCase):
+ @staticmethod
+ def test_espeak_phonemes():
+ # prepare the config
+ config = VitsConfig(
+ batch_size=2,
+ eval_batch_size=2,
+ num_loader_workers=0,
+ num_eval_loader_workers=0,
+ text_cleaner="english_cleaners",
+ use_phonemes=True,
+ use_espeak_phonemes=True,
+ phoneme_language="en-us",
+ phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+ run_eval=True,
+ test_delay_epochs=-1,
+ epochs=1,
+ print_step=1,
+ print_eval=True,
+ datasets=[dataset_config_en, dataset_config_pt],
+ )
+ config.save_json(config_path)
+
+ # run test
+ run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
+
+ @staticmethod
+ def test_no_espeak_phonemes():
+ # prepare the config
+ config = VitsConfig(
+ batch_size=2,
+ eval_batch_size=2,
+ num_loader_workers=0,
+ num_eval_loader_workers=0,
+ text_cleaner="english_cleaners",
+ use_phonemes=True,
+ use_espeak_phonemes=False,
+ phoneme_language="en-us",
+ phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+ run_eval=True,
+ test_delay_epochs=-1,
+ epochs=1,
+ print_step=1,
+ print_eval=True,
+ datasets=[dataset_config_en, dataset_config_pt],
+ )
+ config.save_json(config_path)
+
+ # run test
+ run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
diff --git a/tests/aux_tests/test_remove_silence_vad_script.py b/tests/aux_tests/test_remove_silence_vad_script.py
new file mode 100644
index 00000000..c934e065
--- /dev/null
+++ b/tests/aux_tests/test_remove_silence_vad_script.py
@@ -0,0 +1,29 @@
+import os
+import unittest
+
+import torch
+
+from tests import get_tests_input_path, get_tests_output_path, run_cli
+
+torch.manual_seed(1)
+
+# pylint: disable=protected-access
+class TestRemoveSilenceVAD(unittest.TestCase):
+ @staticmethod
+ def test():
+ # set paths
+ wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs")
+ output_path = os.path.join(get_tests_output_path(), "output_wavs_removed_silence/")
+ output_resample_path = os.path.join(get_tests_output_path(), "output_ljspeech_16khz/")
+
+ # resample audios
+ run_cli(
+ f'CUDA_VISIBLE_DEVICES="" python TTS/bin/resample.py --input_dir "{wav_path}" --output_dir "{output_resample_path}" --output_sr 16000'
+ )
+
+ # run test
+ run_cli(
+ f'CUDA_VISIBLE_DEVICES="" python TTS/bin/remove_silence_using_vad.py --input_dir "{output_resample_path}" --output_dir "{output_path}"'
+ )
+ run_cli(f'rm -rf "{output_resample_path}"')
+ run_cli(f'rm -rf "{output_path}"')
diff --git a/tests/aux_tests/test_speaker_encoder.py b/tests/aux_tests/test_speaker_encoder.py
index 3c897aa9..97b3b92f 100644
--- a/tests/aux_tests/test_speaker_encoder.py
+++ b/tests/aux_tests/test_speaker_encoder.py
@@ -13,7 +13,7 @@ file_path = get_tests_input_path()
class LSTMSpeakerEncoderTests(unittest.TestCase):
# pylint: disable=R0201
def test_in_out(self):
- dummy_input = T.rand(4, 20, 80) # B x T x D
+ dummy_input = T.rand(4, 80, 20) # B x D x T
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
# computing d vectors
@@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
assert output.type() == "torch.FloatTensor"
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
# compute d for a given batch
- dummy_input = T.rand(1, 240, 80) # B x T x D
+        dummy_input = T.rand(1, 80, 240)  # B x D x T
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
assert output.shape[0] == 1
assert output.shape[1] == 256
@@ -44,7 +44,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
class ResNetSpeakerEncoderTests(unittest.TestCase):
# pylint: disable=R0201
def test_in_out(self):
- dummy_input = T.rand(4, 20, 80) # B x T x D
+ dummy_input = T.rand(4, 80, 20) # B x D x T
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
# computing d vectors
@@ -61,7 +61,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
assert output.type() == "torch.FloatTensor"
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
# compute d for a given batch
- dummy_input = T.rand(1, 240, 80) # B x T x D
+ dummy_input = T.rand(1, 80, 240) # B x D x T
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
assert output.shape[0] == 1
assert output.shape[1] == 256
diff --git a/tests/aux_tests/test_speaker_manager.py b/tests/aux_tests/test_speaker_manager.py
index baa50749..fff49b13 100644
--- a/tests/aux_tests/test_speaker_manager.py
+++ b/tests/aux_tests/test_speaker_manager.py
@@ -6,7 +6,7 @@ import torch
from tests import get_tests_input_path
from TTS.config import load_config
-from TTS.speaker_encoder.utils.generic_utils import setup_model
+from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.speaker_encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase):
config.audio.resample = True
# create a dummy speaker encoder
- model = setup_model(config)
+ model = setup_speaker_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
# load audio processor and speaker encoder
@@ -38,7 +38,7 @@ class SpeakerManagerTest(unittest.TestCase):
# load a sample audio and compute embedding
waveform = ap.load_wav(sample_wav_path)
mel = ap.melspectrogram(waveform)
- d_vector = manager.compute_d_vector(mel.T)
+ d_vector = manager.compute_d_vector(mel)
assert d_vector.shape[1] == 256
# compute d_vector directly from an input file
diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py
index 8a20c261..19c2e8f7 100644
--- a/tests/data_tests/test_loader.py
+++ b/tests/data_tests/test_loader.py
@@ -38,6 +38,11 @@ class TestTTSDataset(unittest.TestCase):
def _create_dataloader(self, batch_size, r, bgs):
items = ljspeech(c.data_path, "metadata.csv")
+
+        # add a default language because TTSDataset now expects a language for each sample
+ language = ""
+ items = [[*item, language] for item in items]
+
dataset = TTSDataset(
r,
c.text_cleaner,
diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py
new file mode 100644
index 00000000..3d8d6c75
--- /dev/null
+++ b/tests/data_tests/test_samplers.py
@@ -0,0 +1,58 @@
+import functools
+
+import torch
+
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.utils.languages import get_language_weighted_sampler
+
+# Fixing random state to avoid random fails
+torch.manual_seed(0)
+
+dataset_config_en = BaseDatasetConfig(
+ name="ljspeech",
+ meta_file_train="metadata.csv",
+ meta_file_val="metadata.csv",
+ path="tests/data/ljspeech",
+ language="en",
+)
+
+dataset_config_pt = BaseDatasetConfig(
+ name="ljspeech",
+ meta_file_train="metadata.csv",
+ meta_file_val="metadata.csv",
+ path="tests/data/ljspeech",
+ language="pt-br",
+)
+
+# Adding the EN samples twice to create an unbalanced dataset
+train_samples, eval_samples = load_tts_samples(
+ [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
+)
+
+
+def is_balanced(lang_1, lang_2):
+ return 0.85 < lang_1 / lang_2 < 1.2
+
+
+random_sampler = torch.utils.data.RandomSampler(train_samples)
+ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
+en, pt = 0, 0
+for index in ids:
+ if train_samples[index][3] == "en":
+ en += 1
+ else:
+ pt += 1
+
+assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
+
+weighted_sampler = get_language_weighted_sampler(train_samples)
+ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
+en, pt = 0, 0
+for index in ids:
+ if train_samples[index][3] == "en":
+ en += 1
+ else:
+ pt += 1
+
+assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
diff --git a/tests/inputs/language_ids.json b/tests/inputs/language_ids.json
new file mode 100644
index 00000000..27bb1520
--- /dev/null
+++ b/tests/inputs/language_ids.json
@@ -0,0 +1,5 @@
+{
+ "en": 0,
+ "fr-fr": 1,
+ "pt-br": 2
+}
\ No newline at end of file
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
new file mode 100644
index 00000000..4274d947
--- /dev/null
+++ b/tests/tts_tests/test_vits.py
@@ -0,0 +1,240 @@
+import os
+import unittest
+
+import torch
+
+from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path
+from TTS.config import load_config
+from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
+from TTS.tts.configs.vits_config import VitsConfig
+from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.utils.speakers import SpeakerManager
+
+LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
+SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
+
+
+torch.manual_seed(1)
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if use_cuda else "cpu")
+
+
+# pylint: disable=no-self-use
+class TestVits(unittest.TestCase):
+ def test_init_multispeaker(self):
+ num_speakers = 10
+ args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
+ model = Vits(args)
+ assertHasAttr(self, model, "emb_g")
+
+ args = VitsArgs(num_speakers=0, use_speaker_embedding=True)
+ model = Vits(args)
+ assertHasNotAttr(self, model, "emb_g")
+
+ args = VitsArgs(num_speakers=10, use_speaker_embedding=False)
+ model = Vits(args)
+ assertHasNotAttr(self, model, "emb_g")
+
+ args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
+ model = Vits(args)
+ self.assertEqual(model.embedded_speaker_dim, 101)
+
+ def test_init_multilingual(self):
+ args = VitsArgs(language_ids_file=None, use_language_embedding=False)
+ model = Vits(args)
+ self.assertEqual(model.language_manager, None)
+ self.assertEqual(model.embedded_language_dim, 0)
+ self.assertEqual(model.emb_l, None)
+
+ args = VitsArgs(language_ids_file=LANG_FILE)
+ model = Vits(args)
+ self.assertNotEqual(model.language_manager, None)
+ self.assertEqual(model.embedded_language_dim, 0)
+ self.assertEqual(model.emb_l, None)
+
+ args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
+ model = Vits(args)
+ self.assertNotEqual(model.language_manager, None)
+ self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
+ self.assertNotEqual(model.emb_l, None)
+
+ args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
+ model = Vits(args)
+ self.assertNotEqual(model.language_manager, None)
+ self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
+ self.assertNotEqual(model.emb_l, None)
+
+ def test_get_aux_input(self):
+ aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
+ args = VitsArgs()
+ model = Vits(args)
+ aux_out = model.get_aux_input(aux_input)
+
+ speaker_id = torch.randint(10, (1,))
+ language_id = torch.randint(10, (1,))
+ d_vector = torch.rand(1, 128)
+ aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
+ aux_out = model.get_aux_input(aux_input)
+ self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
+ self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
+ self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape)
+
+ def test_voice_conversion(self):
+ num_speakers = 10
+ spec_len = 101
+ spec_effective_len = 50
+
+ args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
+ model = Vits(args)
+
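+ # random reference spectrogram, its effective length, and source/target speaker ids for voice conversion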
+ ref_inp = torch.randn(1, spec_len, 513)
+ ref_inp_len = torch.randint(1, spec_effective_len, (1,))
+ ref_spk_id = torch.randint(1, num_speakers, (1,))
+ tgt_spk_id = torch.randint(1, num_speakers, (1,))
+ o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)
+
+ self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))
+ self.assertEqual(y_mask.shape, (1, 1, spec_len))
+ self.assertEqual(y_mask.sum(), ref_inp_len[0])
+ self.assertEqual(z.shape, (1, args.hidden_channels, spec_len))
+ self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len))
+ self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len))
+
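+ # helper: build a dummy batch of token ids, spectrograms, and waveforms sized from the audio config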
+ def _init_inputs(self, config):
+ input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+ input_lengths = torch.randint(100, 129, (8,)).long().to(device)
+ input_lengths[-1] = 128
+ spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device)
+ spec_lengths = torch.randint(20, 30, (8,)).long().to(device)
+ spec_lengths[-1] = spec.size(2)
+ waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device)
+ return input_dummy, input_lengths, spec, spec_lengths, waveform
+
+ def _check_forward_outputs(self, config, output_dict, encoder_config=None):
+ self.assertEqual(
+ output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
+ )
+ self.assertEqual(output_dict["alignments"].shape, (8, 128, 30))
+ self.assertEqual(output_dict["alignments"].max(), 1)
+ self.assertEqual(output_dict["alignments"].min(), 0)
+ self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30))
+ self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30))
+ self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30))
+ self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30))
+ self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30))
+ self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30))
+ self.assertEqual(
+ output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
+ )
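+ # speaker-encoder embeddings are only returned when an encoder config is provided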
+ if encoder_config:
+ self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
+ self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
+ else:
+ self.assertEqual(output_dict["gt_spk_emb"], None)
+ self.assertEqual(output_dict["syn_spk_emb"], None)
+
+ def test_forward(self):
+ num_speakers = 0
+ config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
+ config.model_args.spec_segment_size = 10
+ input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
+ model = Vits(config).to(device)
+ output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
+ self._check_forward_outputs(config, output_dict)
+
+ def test_multispeaker_forward(self):
+ num_speakers = 10
+
+ config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
+ config.model_args.spec_segment_size = 10
+
+ input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
+ speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
+
+ model = Vits(config).to(device)
+ output_dict = model.forward(
+ input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids}
+ )
+ self._check_forward_outputs(config, output_dict)
+
+ def test_multilingual_forward(self):
+ num_speakers = 10
+ num_langs = 3
+
+ args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
+ config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
+
+ input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
+ speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
+ lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
+
+ model = Vits(config).to(device)
+ output_dict = model.forward(
+ input_dummy,
+ input_lengths,
+ spec,
+ spec_lengths,
+ waveform,
+ aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
+ )
+ self._check_forward_outputs(config, output_dict)
+
+ def test_secl_forward(self):
+ num_speakers = 10
+ num_langs = 3
+
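+ # attach an external speaker encoder to a SpeakerManager so it can be used as a training loss (use_speaker_encoder_as_loss)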
+ speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
+ speaker_encoder_config.model_params["use_torch_spec"] = True
+ speaker_encoder = setup_speaker_encoder_model(speaker_encoder_config).to(device)
+ speaker_manager = SpeakerManager()
+ speaker_manager.speaker_encoder = speaker_encoder
+
+ args = VitsArgs(
+ language_ids_file=LANG_FILE,
+ use_language_embedding=True,
+ spec_segment_size=10,
+ use_speaker_encoder_as_loss=True,
+ )
+ config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
+ config.audio.sample_rate = 16000
+
+ input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
+ speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
+ lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
+
+ model = Vits(config, speaker_manager=speaker_manager).to(device)
+ output_dict = model.forward(
+ input_dummy,
+ input_lengths,
+ spec,
+ spec_lengths,
+ waveform,
+ aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
+ )
+ self._check_forward_outputs(config, output_dict, speaker_encoder_config)
+
+ def test_inference(self):
+ num_speakers = 0
+ config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
+ input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
+ model = Vits(config).to(device)
+ _ = model.inference(input_dummy)
+
+ def test_multispeaker_inference(self):
+ num_speakers = 10
+ config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
+ input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
+ speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
+ model = Vits(config).to(device)
+ _ = model.inference(input_dummy, {"speaker_ids": speaker_ids})
+
+ def test_multilingual_inference(self):
+ num_speakers = 10
+ num_langs = 3
+ args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
+ config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
+ input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
+ speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
+ lang_ids = torch.randint(0, num_langs, (1,)).long().to(device)
+ model = Vits(config).to(device)
+ _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})
diff --git a/tests/tts_tests/test_vits_d-vectors_train.py b/tests/tts_tests/test_vits_d-vectors_train.py
new file mode 100644
index 00000000..213669f5
--- /dev/null
+++ b/tests/tts_tests/test_vits_d-vectors_train.py
@@ -0,0 +1,62 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+ batch_size=2,
+ eval_batch_size=2,
+ num_loader_workers=0,
+ num_eval_loader_workers=0,
+ text_cleaner="english_cleaners",
+ use_phonemes=True,
+ use_espeak_phonemes=True,
+ phoneme_language="en-us",
+ phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+ run_eval=True,
+ test_delay_epochs=-1,
+ epochs=1,
+ print_step=1,
+ print_eval=True,
+ test_sentences=[
+ ["Be a voice, not an echo.", "ljspeech-0"],
+ ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# activate multi-speaker d-vector mode
+config.model_args.use_d_vector_file = True
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.d_vector_dim = 256
+
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+ f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+ f"--coqpit.output_path {output_path} "
+ "--coqpit.datasets.0.name ljspeech "
+ "--coqpit.datasets.0.meta_file_train metadata.csv "
+ "--coqpit.datasets.0.meta_file_val metadata.csv "
+ "--coqpit.datasets.0.path tests/data/ljspeech "
+ "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+ "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
new file mode 100644
index 00000000..1ca57d93
--- /dev/null
+++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
@@ -0,0 +1,91 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+dataset_config_en = BaseDatasetConfig(
+ name="ljspeech",
+ meta_file_train="metadata.csv",
+ meta_file_val="metadata.csv",
+ path="tests/data/ljspeech",
+ language="en",
+)
+
+dataset_config_pt = BaseDatasetConfig(
+ name="ljspeech",
+ meta_file_train="metadata.csv",
+ meta_file_val="metadata.csv",
+ path="tests/data/ljspeech",
+ language="pt-br",
+)
+
+config = VitsConfig(
+ batch_size=2,
+ eval_batch_size=2,
+ num_loader_workers=0,
+ num_eval_loader_workers=0,
+ text_cleaner="english_cleaners",
+ use_phonemes=False,
+ phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+ run_eval=True,
+ test_delay_epochs=-1,
+ epochs=1,
+ print_step=1,
+ print_eval=True,
+ test_sentences=[
+ ["Be a voice, not an echo.", "ljspeech-0", None, "en"],
+ ["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
+ ],
+ datasets=[dataset_config_en, dataset_config_pt],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# activate multilingual mode
+config.model_args.use_language_embedding = True
+config.use_language_embedding = True
+
+# disable the learned speaker-embedding layer
+config.model_args.use_speaker_embedding = False
+config.use_speaker_embedding = False
+
+# activate multi-speaker d-vector mode instead
+config.model_args.use_d_vector_file = True
+config.use_d_vector_file = True
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.d_vector_dim = 256
+config.d_vector_dim = 256
+
+# use the stochastic duration predictor (SDP)
+config.model_args.use_sdp = True
+config.use_sdp = True
+
+# deactivate the language-weighted sampler
+config.use_language_weighted_sampler = False
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+ f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+ f"--coqpit.output_path {output_path} "
+ "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
diff --git a/tests/tts_tests/test_vits_multilingual_train.py b/tests/tts_tests/test_vits_multilingual_train.py
new file mode 100644
index 00000000..50cccca5
--- /dev/null
+++ b/tests/tts_tests/test_vits_multilingual_train.py
@@ -0,0 +1,88 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+dataset_config_en = BaseDatasetConfig(
+ name="ljspeech",
+ meta_file_train="metadata.csv",
+ meta_file_val="metadata.csv",
+ path="tests/data/ljspeech",
+ language="en",
+)
+
+dataset_config_pt = BaseDatasetConfig(
+ name="ljspeech",
+ meta_file_train="metadata.csv",
+ meta_file_val="metadata.csv",
+ path="tests/data/ljspeech",
+ language="pt-br",
+)
+
+config = VitsConfig(
+ batch_size=2,
+ eval_batch_size=2,
+ num_loader_workers=0,
+ num_eval_loader_workers=0,
+ text_cleaner="english_cleaners",
+ use_phonemes=True,
+ use_espeak_phonemes=True,
+ phoneme_language="en-us",
+ phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+ run_eval=True,
+ test_delay_epochs=-1,
+ epochs=1,
+ print_step=1,
+ print_eval=True,
+ test_sentences=[
+ ["Be a voice, not an echo.", "ljspeech", None, "en"],
+ ["Be a voice, not an echo.", "ljspeech", None, "pt-br"],
+ ],
+ datasets=[dataset_config_en, dataset_config_pt],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# activate multilingual mode
+config.model_args.use_language_embedding = True
+config.use_language_embedding = True
+# activate multi-speaker mode
+config.model_args.use_speaker_embedding = True
+config.use_speaker_embedding = True
+
+# deactivate multispeaker d-vec mode
+config.model_args.use_d_vector_file = False
+config.use_d_vector_file = False
+
+# use the deterministic duration predictor (disable SDP)
+config.model_args.use_sdp = False
+config.use_sdp = False
+
+# activate the language-weighted sampler
+config.use_language_weighted_sampler = True
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+ f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+ f"--coqpit.output_path {output_path} "
+ "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py
new file mode 100644
index 00000000..6cc1dabd
--- /dev/null
+++ b/tests/tts_tests/test_vits_speaker_emb_train.py
@@ -0,0 +1,63 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+ batch_size=2,
+ eval_batch_size=2,
+ num_loader_workers=0,
+ num_eval_loader_workers=0,
+ text_cleaner="english_cleaners",
+ use_phonemes=True,
+ use_espeak_phonemes=True,
+ phoneme_language="en-us",
+ phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+ run_eval=True,
+ test_delay_epochs=-1,
+ epochs=1,
+ print_step=1,
+ print_eval=True,
+ test_sentences=[
+ ["Be a voice, not an echo.", "ljspeech"],
+ ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# activate multi-speaker mode with learned speaker embeddings (d-vectors disabled)
+config.model_args.use_speaker_embedding = True
+config.model_args.use_d_vector_file = False
+config.model_args.d_vector_file = None
+config.model_args.d_vector_dim = 256
+
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+ f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+ f"--coqpit.output_path {output_path} "
+ "--coqpit.datasets.0.name ljspeech "
+ "--coqpit.datasets.0.meta_file_train metadata.csv "
+ "--coqpit.datasets.0.meta_file_val metadata.csv "
+ "--coqpit.datasets.0.path tests/data/ljspeech "
+ "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+ "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py
index 6398955e..607f7b29 100644
--- a/tests/tts_tests/test_vits_train.py
+++ b/tests/tts_tests/test_vits_train.py
@@ -25,7 +25,7 @@ config = VitsConfig(
print_step=1,
print_eval=True,
test_sentences=[
- "Be a voice, not an echo.",
+ ["Be a voice, not an echo."],
],
)
config.audio.do_trim_silence = True
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py
index 886d1bb6..63d9e7ca 100644
--- a/tests/zoo_tests/test_models.py
+++ b/tests/zoo_tests/test_models.py
@@ -4,6 +4,7 @@ import os
import shutil
from tests import get_tests_output_path, run_cli
+from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager
@@ -17,21 +18,30 @@ def test_run_all_models():
manager = ModelManager(output_prefix=get_tests_output_path())
model_names = manager.list_models()
for model_name in model_names:
+ print(f"\n > Run - {model_name}")
model_path, _, _ = manager.download_model(model_name)
if "tts_models" in model_name:
local_download_dir = os.path.dirname(model_path)
# download and run the model
speaker_files = glob.glob(local_download_dir + "/speaker*")
+ language_files = glob.glob(local_download_dir + "/language*")
+ language_id = ""
if len(speaker_files) > 0:
# multi-speaker model
if "speaker_ids" in speaker_files[0]:
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
elif "speakers" in speaker_files[0]:
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
+
+ # multilingual model (assumes multilingual models are also multi-speaker)
+ if len(language_files) > 0 and "language_ids" in language_files[0]:
+ language_manager = LanguageManager(language_ids_file_path=language_files[0])
+ language_id = language_manager.language_names[0]
+
speaker_id = list(speaker_manager.speaker_ids.keys())[0]
run_cli(
f"tts --model_name {model_name} "
- f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"'
+ f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
)
else:
# single-speaker model