mirror of https://github.com/coqui-ai/TTS.git
commit df2422eb72

@ -0,0 +1,52 @@
name: zoo-tests-tortoise

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, synchronize, reopened]

jobs:
  check_skip:
    runs-on: ubuntu-latest
    if: "! contains(github.event.head_commit.message, '[ci skip]')"
    steps:
      - run: echo "${{ github.event.head_commit.message }}"

  test:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.9, "3.10", "3.11"]
        experimental: [false]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          architecture: x64
          cache: 'pip'
          cache-dependency-path: 'requirements*'
      - name: check OS
        run: cat /etc/os-release
      - name: set ENV
        run: export TRAINER_TELEMETRY=0
      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y git make gcc
          sudo apt-get install espeak espeak-ng
          make system-deps
      - name: Install/upgrade Python setup deps
        run: python3 -m pip install --upgrade pip setuptools wheel
      - name: Replace scarf urls
        run: |
          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
      - name: Install TTS
        run: |
          python3 -m pip install .[all]
          python3 setup.py egg_info
      - name: Unit tests
        run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_tortoise
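For local debugging outside CI, this is roughly what the Tortoise zoo test exercises; a sketch only, since the real test lives in tests/zoo_tests/test_models.py and its exact assertions differ, and the model id below is the Tortoise entry assumed to be registered in TTS/.models.json.

```
from TTS.api import TTS

# Download the released Tortoise model and synthesize one sentence, as the zoo test does.
tts = TTS("tts_models/en/multi-dataset/tortoise-v2")
tts.tts_to_file(text="This is a Tortoise zoo test sentence.", file_path="tortoise_out.wav")
```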
README.md
@ -1,6 +1,7 @@

  ## 🐸Coqui.ai News
- - 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with uncontrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html)
+ - 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://tts.readthedocs.io/en/dev/models/xtts.html)
+ - 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html)
  - 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
  - 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://tts.readthedocs.io/en/dev/models/tortoise.html)
  - 📣 **Coqui Studio API** is landed on 🐸TTS. - [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api)

@ -111,7 +112,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea
  - Delightful TTS: [paper](https://arxiv.org/abs/2110.12612)

  ### End-to-End Models
- - ⓍTTS: [blog]()
+ - ⓍTTS: [blog](https://coqui.ai/blog/tts/open_xtts)
  - VITS: [paper](https://arxiv.org/pdf/2106.06103)
  - 🐸 YourTTS: [paper](https://arxiv.org/abs/2112.02418)
  - 🐢 Tortoise: [orig. repo](https://github.com/neonbjb/tortoise-tts)
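As a hedged illustration of the ⓍTTS release announced above, the model can be driven through the Python API roughly as follows; the model id is the `xtts_v1` entry from `.models.json`, the reference wav path is a placeholder, and this is a sketch rather than an excerpt from the repository.

```
from TTS.api import TTS

# Cross-language voice cloning with XTTS-v1; `speaker_wav` is a short reference clip of
# the voice to clone.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
tts.tts_to_file(
    text="It took me quite a long time to develop a voice.",
    speaker_wav="path/to/reference.wav",
    language="en",
    file_path="xtts_output.wav",
)
```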
@ -293,99 +294,123 @@ api.tts_with_vc_to_file(
```

### Command-line `tts`

<!-- begin-tts-readme -->

Synthesize speech on command line.

You can either use your trained model or choose a model from the provided list.

If you don't specify any models, then it uses the LJSpeech-based English model.

#### Single Speaker Models

- List provided models:

  ```
  $ tts --list_models
  ```

- Get model info (for both tts_models and vocoder_models):

  - Query by type/name:
    The model_info_by_name uses the name as it appears in the output of --list_models.
    ```
    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
    ```
    For example:
    ```
    $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
    $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
    ```
  - Query by type/idx:
    The model_query_idx uses the corresponding idx from --list_models.

    ```
    $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
    ```

    For example:

    ```
    $ tts --model_info_by_idx tts_models/3
    ```

  - Query info for model info by full name:
    ```
    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
    ```

- Run TTS with default models:

  ```
  $ tts --text "Text for TTS" --out_path output/path/speech.wav
  ```

- Run a TTS model with its default vocoder model:

  ```
  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
  ```

  For example:

  ```
  $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
  ```

- Run with specific TTS and vocoder models from the list:

  ```
  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
  ```

  For example:

  ```
  $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
  ```

- Run your own TTS model (Using Griffin-Lim Vocoder):

  ```
  $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
  ```

- Run your own TTS and Vocoder models:

  ```
  $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
      --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
  ```

#### Multi-speaker Models

- List the available speakers and choose a <speaker_id> among them:

  ```
  $ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
  ```

- Run the multi-speaker TTS model with the target speaker ID:

  ```
  $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
  ```

- Run your own multi-speaker TTS model:

  ```
  $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
  ```

### Voice Conversion Models

```
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
```

<!-- end-tts-readme -->
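The same flows are also reachable from the Python API; a sketch under the assumption that the model names below (taken from `tts --list_models`) are still released, not a definitive recipe.

```
from TTS.api import TTS

# Multi-speaker synthesis, mirroring --list_speaker_idxs and --speaker_idx above.
tts = TTS(model_name="tts_models/en/vctk/vits")
print(tts.speakers)
tts.tts_to_file(text="Text for TTS.", speaker=tts.speakers[0], file_path="output/path/speech.wav")

# Voice conversion, mirroring the last command above.
vc = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24")
vc.voice_conversion_to_file(source_wav="path/to/speaker.wav", target_wav="path/to/reference.wav", file_path="output/path/speech.wav")
```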
|
||||
|
||||
## Directory Structure
|
||||
```
|
||||
@ -5,12 +5,12 @@
      "xtts_v1": {
          "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
          "hf_url": [
-             "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
-             "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
-             "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
+             "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
+             "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/config.json",
+             "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/vocab.json"
          ],
          "default_vocoder": null,
-         "commit": "e9a1953e",
+         "commit": "e5140314",
          "license": "CPML",
          "contact": "info@coqui.ai",
          "tos_required": true

@ -728,6 +728,18 @@
              "license": "Apache 2.0"
            }
          }
        },
+       "be": {
+         "common-voice": {
+           "glow-tts":{
+             "description": "Belarusian GlowTTS model created by @alex73 (Github).",
+             "github_rls_url":"https://coqui.gateway.scarf.sh/v0.16.6/tts_models--be--common-voice--glow-tts.zip",
+             "default_vocoder": "vocoder_models/be/common-voice/hifigan",
+             "commit": "c0aabb85",
+             "license": "CC-BY-SA 4.0",
+             "contact": "alex73mail@gmail.com"
+           }
+         }
+       }
      },
      "vocoder_models": {
@ -879,6 +891,17 @@
            "commit": null
          }
        }
      },
+     "be": {
+       "common-voice": {
+         "hifigan": {
+           "github_rls_url": "https://coqui.gateway.scarf.sh/v0.16.6/vocoder_models--be--common-voice--hifigan.zip",
+           "description": "Belarusian HiFiGAN model created by @alex73 (Github).",
+           "author": "@alex73",
+           "license": "CC-BY-SA 4.0",
+           "commit": "c0aabb85"
+         }
+       }
+     }
    },
    "voice_conversion_models": {
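The new `.models.json` entries are addressed through the usual `<type>/<language>/<dataset>/<name>` ids; a hedged sketch of how the Belarusian addition would be used from the Python API (text and output path are placeholders):

```
from TTS.api import TTS

# List released ids; after this change the list should include tts_models/be/common-voice/glow-tts.
print(TTS.list_models())

# Download the Belarusian GlowTTS model; its default Belarusian HiFiGAN vocoder is resolved automatically.
tts = TTS("tts_models/be/common-voice/glow-tts")
tts.tts_to_file(text="Прывітанне, свет!", file_path="be_sample.wav")
```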
@ -1 +1 @@
- 0.17.2
+ 0.17.8
@ -17,7 +17,7 @@ class TTS(nn.Module):

      def __init__(
          self,
-         model_name: str = None,
+         model_name: str = "",
          model_path: str = None,
          config_path: str = None,
          vocoder_path: str = None,
@ -105,13 +105,14 @@ class TTS(nn.Module):

      @property
      def is_multi_lingual(self):
-         # TODO: fix this
-         if "xtts" in self.model_name:
+         # Not sure what sets this to None, but applied a fix to prevent crashing.
+         if isinstance(self.model_name, str) and "xtts" in self.model_name:
              return True
          if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
              return self.synthesizer.tts_model.language_manager.num_languages > 1
          return False


      @property
      def speakers(self):
          if not self.is_multi_speaker:
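A small sketch of how callers typically consume `is_multi_lingual`: only multilingual models need a `language` argument. The model id below is a placeholder taken from the released list, not something asserted by this diff.

```
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/your_tts")
kwargs = {"language": "en"} if tts.is_multi_lingual else {}
speaker = tts.speakers[0] if tts.is_multi_speaker else None
tts.tts_to_file(text="Hello world.", speaker=speaker, file_path="out.wav", **kwargs)
```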
|
|
@ -8,9 +8,120 @@ from argparse import RawTextHelpFormatter
|
|||
# pylint: disable=redefined-outer-name, unused-argument
|
||||
from pathlib import Path
|
||||
|
||||
from TTS.api import TTS
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
description = """
|
||||
Synthesize speech on command line.
|
||||
|
||||
You can either use your trained model or choose a model from the provided list.
|
||||
|
||||
If you don't specify any models, then it uses the LJSpeech-based English model.
|
||||
|
||||
#### Single Speaker Models
|
||||
|
||||
- List provided models:
|
||||
|
||||
```
|
||||
$ tts --list_models
|
||||
```
|
||||
|
||||
- Get model info (for both tts_models and vocoder_models):
|
||||
|
||||
- Query by type/name:
|
||||
The model_info_by_name uses the name as it appears in the output of --list_models.
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
For example:
|
||||
```
|
||||
$ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
|
||||
$ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
|
||||
```
|
||||
- Query by type/idx:
|
||||
The model_query_idx uses the corresponding idx from --list_models.
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx tts_models/3
|
||||
```
|
||||
|
||||
- Query info for model info by full name:
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
|
||||
- Run TTS with default models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run a TTS model with its default vocoder model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run with specific TTS and vocoder models from the list:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS model (Using Griffin-Lim Vocoder):
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS and Vocoder models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
|
||||
```
|
||||
|
||||
#### Multi-speaker Models
|
||||
|
||||
- List the available speakers and choose a <speaker_id> among them:
|
||||
|
||||
```
|
||||
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||
```
|
||||
|
||||
- Run the multi-speaker TTS model with the target speaker ID:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
- Run your own multi-speaker TTS model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
### Voice Conversion Models
|
||||
|
||||
```
|
||||
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
|
@ -24,92 +135,6 @@ def str2bool(v):
|
|||
|
||||
|
||||
def main():
|
||||
description = """Synthesize speech on command line.
|
||||
|
||||
You can either use your trained model or choose a model from the provided list.
|
||||
|
||||
If you don't specify any models, then it uses LJSpeech based English model.
|
||||
|
||||
## Example Runs
|
||||
|
||||
### Single Speaker Models
|
||||
|
||||
- List provided models:
|
||||
|
||||
```
|
||||
$ tts --list_models
|
||||
```
|
||||
|
||||
- Query info for model info by idx:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
|
||||
```
|
||||
|
||||
- Query info for model info by full name:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
|
||||
- Run TTS with default models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS"
|
||||
```
|
||||
|
||||
- Run a TTS model with its default vocoder model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>
|
||||
```
|
||||
|
||||
- Run with specific TTS and vocoder models from the list:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --output_path
|
||||
```
|
||||
|
||||
- Run your own TTS model (Using Griffin-Lim Vocoder):
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS and Vocoder models:
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth --out_path output/path/speech.wav
|
||||
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
|
||||
```
|
||||
|
||||
### Multi-speaker Models
|
||||
|
||||
- List the available speakers and choose as <speaker_id> among them:
|
||||
|
||||
```
|
||||
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||
```
|
||||
|
||||
- Run the multi-speaker TTS model with the target speaker ID:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
- Run your own multi-speaker TTS model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
### Voice Conversion Models
|
||||
|
||||
```
|
||||
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
|
||||
```
|
||||
"""
|
||||
# We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
|
||||
# documentation in sync more easily.
|
||||
parser = argparse.ArgumentParser(
|
||||
description=description.replace(" ```\n", ""),
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
|
@ -310,6 +335,11 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
if not any(check_args):
|
||||
parser.parse_args(["-h"])
|
||||
|
||||
# Late-import to make things load faster
|
||||
from TTS.api import TTS
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
|
||||
# load model manager
|
||||
path = Path(__file__).parent / "../.models.json"
|
||||
manager = ModelManager(path, progress_bar=args.progress_bar)
|
||||
|
|
|
@ -1085,31 +1085,6 @@ class GaussianDiffusion:
|
|||
}
|
||||
|
||||
|
||||
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
||||
"""
|
||||
Get a pre-defined beta schedule for the given name.
|
||||
|
||||
The beta schedule library consists of beta schedules which remain similar
|
||||
in the limit of num_diffusion_timesteps.
|
||||
Beta schedules may be added, but should not be removed or changed once
|
||||
they are committed to maintain backwards compatibility.
|
||||
"""
|
||||
if schedule_name == "linear":
|
||||
# Linear schedule from Ho et al, extended to work for any number of
|
||||
# diffusion steps.
|
||||
scale = 1000 / num_diffusion_timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
||||
elif schedule_name == "cosine":
|
||||
return betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
||||
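For context on the cosine branch above, `betas_for_alpha_bar` is conventionally implemented as below; a sketch following the standard improved-diffusion helper, which may differ in small details from the copy this repository ships.

```
import math
import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # Derive per-step betas from a cumulative-product schedule alpha_bar(t), t in [0, 1],
    # clamping each beta so no single step destroys the signal entirely.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

# Example: the cosine schedule passed in by get_named_beta_schedule above.
betas = betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)
```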
|
||||
|
||||
class SpacedDiffusion(GaussianDiffusion):
|
||||
"""
|
||||
A diffusion process which can skip steps in a base diffusion process.
|
||||
|
|
|
@ -5,9 +5,13 @@ from tokenizers import Tokenizer
|
|||
|
||||
from TTS.tts.utils.text.cleaners import english_cleaners
|
||||
|
||||
DEFAULT_VOCAB_FILE = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
|
||||
)
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file=None, vocab_str=None):
|
||||
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, vocab_str=None):
|
||||
self.tokenizer = None
|
||||
if vocab_file is not None:
|
||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
|
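With this default in place, the tokenizer can be constructed without arguments; a sketch only, since the module path and the `encode` helper are assumptions about the surrounding file, which is not fully shown in this diff.

```
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer

tok = VoiceBpeTokenizer()          # now falls back to the bundled tortoise tokenizer.json
ids = tok.encode("Hello world.")   # assumed helper; previously self.tokenizer stayed None without a vocab_file
```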
|
|
@ -1170,31 +1170,6 @@ class GaussianDiffusion:
|
|||
}
|
||||
|
||||
|
||||
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
||||
"""
|
||||
Get a pre-defined beta schedule for the given name.
|
||||
|
||||
The beta schedule library consists of beta schedules which remain similar
|
||||
in the limit of num_diffusion_timesteps.
|
||||
Beta schedules may be added, but should not be removed or changed once
|
||||
they are committed to maintain backwards compatibility.
|
||||
"""
|
||||
if schedule_name == "linear":
|
||||
# Linear schedule from Ho et al, extended to work for any number of
|
||||
# diffusion steps.
|
||||
scale = 1000 / num_diffusion_timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
||||
elif schedule_name == "cosine":
|
||||
return betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
||||
|
||||
|
||||
class SpacedDiffusion(GaussianDiffusion):
|
||||
"""
|
||||
A diffusion process which can skip steps in a base diffusion process.
|
||||
|
|
|
@ -172,7 +172,7 @@ class GPT(nn.Module):
|
|||
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
||||
}
|
||||
|
||||
def init_gpt_for_inference(self, kv_cache=True):
|
||||
def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
|
||||
seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=self.max_mel_tokens,
|
||||
|
@ -195,6 +195,17 @@ class GPT(nn.Module):
|
|||
)
|
||||
self.gpt.wte = self.mel_embedding
|
||||
|
||||
if use_deepspeed:
|
||||
import deepspeed
|
||||
self.ds_engine = deepspeed.init_inference(
|
||||
model=self.gpt_inference.half(), # Transformers models
|
||||
mp_size=1,  # Number of GPUs
|
||||
dtype=torch.float32, # desired data type of output
|
||||
replace_method="auto", # Lets DS autmatically identify the layer to replace
|
||||
replace_with_kernel_inject=True, # replace the model with the kernel injector
|
||||
)
|
||||
self.gpt_inference = self.ds_engine.module.eval()
|
||||
|
||||
def set_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
|
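A hedged usage sketch of the new flag added above (requires `pip install deepspeed`); here `gpt` stands for an already-constructed GPT instance from this module, which the diff does not show being built.

```
import torch

# Only enable the optional DeepSpeed kernel-injection path when a GPU is available.
gpt.init_gpt_for_inference(kv_cache=True, use_deepspeed=torch.cuda.is_available())
```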
@ -543,3 +554,14 @@ class GPT(nn.Module):
|
|||
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||
return gen.sequences[:, gpt_inputs.shape[1] :], gen
|
||||
return gen[:, gpt_inputs.shape[1] :]
|
||||
|
||||
def get_generator(self, fake_inputs, **hf_generate_kwargs):
|
||||
return self.gpt_inference.generate_stream(
|
||||
fake_inputs,
|
||||
bos_token_id=self.start_audio_token,
|
||||
pad_token_id=self.stop_audio_token,
|
||||
eos_token_id=self.stop_audio_token,
|
||||
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
|
||||
do_stream=True,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
|
|
@ -1,658 +0,0 @@
|
|||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
||||
|
||||
class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||
"""Override GPT2LMHeadModel to allow for prefix conditioning."""
|
||||
|
||||
def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
|
||||
super().__init__(config)
|
||||
self.transformer = gpt
|
||||
self.pos_embedding = pos_emb
|
||||
self.embeddings = embeddings
|
||||
self.final_norm = norm
|
||||
self.lm_head = nn.Sequential(norm, linear)
|
||||
self.kv_cache = kv_cache
|
||||
|
||||
def store_prefix_emb(self, prefix_emb):
|
||||
self.cached_prefix_emb = prefix_emb
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None) # usually None
|
||||
if not self.kv_cache:
|
||||
past_key_values = None
|
||||
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values is not None:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values is not None:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
assert self.cached_prefix_emb is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
|
||||
|
||||
# Create embedding
|
||||
prefix_len = self.cached_prefix_emb.shape[1]
|
||||
if input_ids.shape[1] != 1:
|
||||
gen_inputs = input_ids[:, prefix_len:]
|
||||
gen_emb = self.embeddings(gen_inputs)
|
||||
gen_emb = gen_emb + self.pos_embedding(gen_emb)
|
||||
if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
|
||||
prefix_emb = self.cached_prefix_emb.repeat_interleave(
|
||||
gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
|
||||
)
|
||||
else:
|
||||
prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
|
||||
emb = torch.cat([prefix_emb, gen_emb], dim=1)
|
||||
else:
|
||||
emb = self.embeddings(input_ids)
|
||||
emb = emb + self.pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
inputs_embeds=emb,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (lm_logits,) + transformer_outputs[1:]
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=None,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
||||
class LearnedPositionEmbeddings(nn.Module):
|
||||
def __init__(self, seq_len, model_channels, init_std=0.02, relative=False):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(seq_len, model_channels)
|
||||
nn.init.normal_(self.emb.weight, mean=0.0, std=init_std)
|
||||
self.relative = relative
|
||||
|
||||
def forward(self, x):
|
||||
seq_len = x.shape[1]
|
||||
if self.relative:
|
||||
start = torch.randint(seq_len, (1,), device=x.device).item()
|
||||
positions = torch.arange(start, start + seq_len, device=x.device)
|
||||
else:
|
||||
positions = torch.arange(seq_len, device=x.device)
|
||||
return self.emb(positions)
|
||||
|
||||
def get_fixed_embedding(self, ind, dev):
|
||||
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||
|
||||
|
||||
def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing):
|
||||
"""
|
||||
Initializes a GPT-2 model and its position embeddings for a text-to-speech system.
|
||||
|
||||
Args:
|
||||
layers (int): Number of layers in the GPT-2 model.
|
||||
model_channels (int): Dimension of the GPT-2 model.
|
||||
heads (int): Number of heads in the GPT-2 model.
|
||||
max_mel_seq_len (int): Maximum sequence length for the mel spectrogram.
|
||||
max_text_seq_len (int): Maximum sequence length for the text.
|
||||
max_prompt_len (int): Maximum length of the prompt.
|
||||
checkpointing (bool): Whether to use gradient checkpointing.
|
||||
|
||||
Returns:
|
||||
gpt (GPT2Model): GPT-2 model.
|
||||
mel_pos_emb (LearnedPositionEmbeddings): Position embeddings for the mel spectrogram.
|
||||
text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text.
|
||||
"""
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=123,
|
||||
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_embd=model_channels,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing,
|
||||
)
|
||||
gpt = GPT2Model(gpt_config)
|
||||
|
||||
del gpt.wpe
|
||||
del gpt.wte
|
||||
|
||||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels)
|
||||
|
||||
audio_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_mel_seq_len, model_channels)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_channels)
|
||||
)
|
||||
text_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_text_seq_len, model_channels)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_channels)
|
||||
)
|
||||
|
||||
return gpt, audio_pos_emb, text_pos_emb
|
||||
|
||||
|
||||
class XTTSGPTEncoder(nn.Module):
|
||||
"""XTTS GPT Encoder model implementation.
|
||||
Args:
|
||||
start_text_token (int): Index of the start token in the text vocabulary.
|
||||
stop_text_token (int): Index of the stop token in the text vocabulary.
|
||||
n_layers (int): Number of layers in the GPT-2 model.
|
||||
n_model_channels (int): Dimension of the GPT-2 model.
|
||||
n_heads (int): Number of heads in the GPT-2 model.
|
||||
max_text_tokens (int): Maximum number of text tokens.
|
||||
max_audio_tokens (int): Maximum number of audio tokens.
|
||||
max_prompt_tokens (int): Maximum number of prompt tokens.
|
||||
audio_len_compression (int): Compression factor for the audio length.
|
||||
number_text_tokens (int): Number of text tokens.
|
||||
number_audio_codes (int): Number of audio codes.
|
||||
start_mel_token (int): Index of the start token in the mel code vocabulary.
|
||||
stop_mel_token (int): Index of the stop token in the mel code vocabulary.
|
||||
checkpointing (bool): Whether or not to use gradient checkpointing at training.
|
||||
"""
|
||||
|
||||
_inference_flag = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_text_token=261,
|
||||
stop_text_token=0,
|
||||
n_layers=8,
|
||||
n_model_channels=512,
|
||||
n_heads=8,
|
||||
max_text_tokens=120,
|
||||
max_audio_tokens=250,
|
||||
max_prompt_tokens=70,
|
||||
audio_len_compression=1024,
|
||||
number_text_tokens=256,
|
||||
number_audio_codes=8194,
|
||||
start_mel_token=8192,
|
||||
stop_mel_token=8193,
|
||||
checkpointing=True,
|
||||
label_smoothing=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.label_smoothing = label_smoothing
|
||||
self.number_text_tokens = number_text_tokens
|
||||
self.start_text_token = start_text_token
|
||||
self.stop_text_token = stop_text_token
|
||||
self.number_audio_codes = number_audio_codes
|
||||
self.start_mel_token = start_mel_token
|
||||
self.stop_mel_token = stop_mel_token
|
||||
self.start_prompt_token = start_mel_token
|
||||
self.stop_prompt_token = stop_mel_token
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.n_model_channels = n_model_channels
|
||||
self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs
|
||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
|
||||
self.max_prompt_tokens = max_prompt_tokens
|
||||
self.audio_len_compression = audio_len_compression
|
||||
|
||||
# embedding layers
|
||||
self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels)
|
||||
self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
|
||||
self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
|
||||
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels)
|
||||
|
||||
# initialize the GPT-2 model
|
||||
(
|
||||
self.gpt,
|
||||
self.audio_pos_embedding,
|
||||
self.text_pos_embedding,
|
||||
) = init_gpt(
|
||||
n_layers,
|
||||
n_model_channels,
|
||||
n_heads,
|
||||
self.max_audio_tokens,
|
||||
self.max_text_tokens,
|
||||
self.max_prompt_tokens,
|
||||
checkpointing,
|
||||
)
|
||||
|
||||
# output layers
|
||||
self.final_norm = nn.LayerNorm(n_model_channels)
|
||||
self.text_head = nn.Linear(n_model_channels, self.number_text_tokens)
|
||||
self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
|
||||
"gpt": list(self.gpt.parameters()),
|
||||
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
||||
}
|
||||
|
||||
def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
|
||||
self._inference_flag = True
|
||||
seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=self.max_audio_tokens,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=self.n_model_channels,
|
||||
n_layer=self.n_layers,
|
||||
n_head=self.n_heads,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True,
|
||||
)
|
||||
self.inference_model = GPT2InferenceModel(
|
||||
gpt_config,
|
||||
self.gpt,
|
||||
self.audio_pos_embedding,
|
||||
self.audio_embedding,
|
||||
self.final_norm,
|
||||
self.mel_head,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
self.gpt.wte = self.audio_embedding
|
||||
|
||||
def set_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
return inp, tar
|
||||
|
||||
def set_audio_tokens_padding(self, audio_tokens, audio_token_lens):
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
for b in range(len(audio_token_lens)):
|
||||
actual_end = audio_token_lens[b]
|
||||
if actual_end < audio_tokens.shape[-1]:
|
||||
audio_tokens[b, actual_end:] = self.stop_mel_token
|
||||
return audio_tokens
|
||||
|
||||
def get_logits(
|
||||
self,
|
||||
speech_conditioning_inputs,
|
||||
first_inputs,
|
||||
first_head,
|
||||
second_inputs=None,
|
||||
second_head=None,
|
||||
prompt=None,
|
||||
get_attns=False,
|
||||
return_latent=False,
|
||||
attn_mask_text=None,
|
||||
attn_mask_mel=None,
|
||||
):
|
||||
if prompt is not None and speech_conditioning_inputs is not None:
|
||||
offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat(
|
||||
[speech_conditioning_inputs, prompt, first_inputs, second_inputs],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
|
||||
elif speech_conditioning_inputs is not None:
|
||||
offset = speech_conditioning_inputs.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
||||
elif prompt is not None:
|
||||
offset = prompt.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([prompt, first_inputs], dim=1)
|
||||
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||
attn_mask = None
|
||||
if attn_mask_text is not None:
|
||||
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
|
||||
if prompt is not None:
|
||||
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
|
||||
|
||||
gpt_out = self.gpt(
|
||||
inputs_embeds=emb,
|
||||
return_dict=True,
|
||||
output_attentions=get_attns,
|
||||
attention_mask=attn_mask,
|
||||
)
|
||||
|
||||
if get_attns:
|
||||
return gpt_out.attentions
|
||||
|
||||
enc = gpt_out.last_hidden_state[:, offset:]
|
||||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
|
||||
|
||||
first_logits = enc[:, : first_inputs.shape[1]]
|
||||
first_logits = first_head(first_logits)
|
||||
first_logits = first_logits.permute(0, 2, 1)
|
||||
if second_inputs is not None:
|
||||
second_logits = enc[:, -second_inputs.shape[1] :]
|
||||
second_logits = second_head(second_logits)
|
||||
second_logits = second_logits.permute(0, 2, 1)
|
||||
return first_logits, second_logits
|
||||
else:
|
||||
return first_logits
|
||||
|
||||
def get_conditioning(self, speech_conditioning_input):
|
||||
speech_conditioning_input = (
|
||||
speech_conditioning_input.unsqueeze(1)
|
||||
if len(speech_conditioning_input.shape) == 3
|
||||
else speech_conditioning_input
|
||||
)
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1)
|
||||
return conds
|
||||
|
||||
def get_prompts(self, prompt_codes):
|
||||
prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token)
|
||||
prompt = F.pad(prompt_codes, (0, 1), value=self.stop_prompt_token)
|
||||
return prompt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_inputs,
|
||||
text_lengths,
|
||||
audio_codes,
|
||||
wav_lengths,
|
||||
prompt_codes,
|
||||
return_attentions=False,
|
||||
return_latent=False,
|
||||
):
|
||||
max_text_len = text_lengths.max()
|
||||
|
||||
# Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values
|
||||
# Like [..., 186, 45, 45, 83] where actually it should end with 186.
|
||||
# We take last 3 codes to prevent abrupt ending of the audio.
|
||||
# TODO: This might need some testing.
|
||||
mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3
|
||||
|
||||
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
|
||||
max_mel_len = mel_lengths.max()
|
||||
|
||||
if max_mel_len > audio_codes.shape[-1]:
|
||||
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
|
||||
|
||||
# silence aware lengths, skip the silence tokens at the end of the mel codes.
|
||||
silence = True
|
||||
for idx, l in enumerate(mel_lengths):
|
||||
length = l.item()
|
||||
while silence:
|
||||
if audio_codes[idx, length - 1] != 83:
|
||||
break
|
||||
length -= 1
|
||||
mel_lengths[idx] = length
|
||||
|
||||
# Lovely assertions
|
||||
assert (
|
||||
max_mel_len <= audio_codes.shape[-1]
|
||||
), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
|
||||
assert (
|
||||
max_text_len <= text_inputs.shape[-1]
|
||||
), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
|
||||
|
||||
# Append stop token to text inputs
|
||||
text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
|
||||
|
||||
# Append silence token to mel codes
|
||||
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
|
||||
|
||||
# Pad mel codes with STOP_MEL_TOKEN
|
||||
audio_codes = self.set_mel_padding(audio_codes, mel_lengths)
|
||||
|
||||
# Compute speech conditioning input
|
||||
conds = None
|
||||
if speech_conditioning_input is not None:
|
||||
if not return_latent:
|
||||
# Compute speech conditioning input
|
||||
speech_conditioning_input = (
|
||||
speech_conditioning_input.unsqueeze(1)
|
||||
if len(speech_conditioning_input.shape) == 3
|
||||
else speech_conditioning_input
|
||||
)
|
||||
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
else:
|
||||
# already computed
|
||||
conds = speech_conditioning_input.unsqueeze(1)
|
||||
|
||||
# Build input and target tensors
|
||||
# Prepend start token to inputs and append stop token to targets
|
||||
text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token)
|
||||
|
||||
# Set attn_mask
|
||||
attn_mask_text = None
|
||||
attn_mask_mel = None
|
||||
if not return_latent:
|
||||
attn_mask_text = torch.ones(
|
||||
text_inputs.shape[0],
|
||||
text_inputs.shape[1],
|
||||
dtype=torch.bool,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
attn_mask_mel = torch.ones(
|
||||
audio_codes.shape[0],
|
||||
audio_codes.shape[1],
|
||||
dtype=torch.bool,
|
||||
device=audio_codes.device,
|
||||
)
|
||||
|
||||
for idx, l in enumerate(text_lengths):
|
||||
attn_mask_text[idx, l + 1 :] = 0.0
|
||||
|
||||
for idx, l in enumerate(mel_lengths):
|
||||
attn_mask_mel[idx, l + 1 :] = 0.0
|
||||
|
||||
# Compute text embeddings + positional embeddings
|
||||
# print(" > text input latent:", text_inputs)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
# Compute mel embeddings + positional embeddings
|
||||
audio_emb = self.audio_embedding(audio_codes) + self.audio_embedding(audio_codes)
|
||||
|
||||
# Compute prompt embeddings + positional embeddings
|
||||
prompt = self.get_prompts(prompt_codes)
|
||||
|
||||
# prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
|
||||
prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt)
|
||||
|
||||
# dropout prompt embeddings
|
||||
prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training)
|
||||
|
||||
# Get logits
|
||||
sub = -4 # don't ask me why 😄
|
||||
if self.training:
|
||||
sub = -1
|
||||
_, audio_logits = self.get_logits(
|
||||
conds,
|
||||
text_emb,
|
||||
self.text_head,
|
||||
audio_emb,
|
||||
self.mel_head,
|
||||
prompt=prompt_emb,
|
||||
get_attns=return_attentions,
|
||||
return_latent=return_latent,
|
||||
attn_mask_text=attn_mask_text,
|
||||
attn_mask_mel=attn_mask_mel,
|
||||
)
|
||||
return audio_logits[:, :sub] # sub to prevent bla.
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
speech_conditioning_latent,
|
||||
text_inputs,
|
||||
input_tokens=None,
|
||||
prompt_codes=None,
|
||||
pad_input_text=False,
|
||||
):
|
||||
"""Compute all the embeddings needed for inference."""
|
||||
if pad_input_text and text_inputs.shape[1] < 250:
|
||||
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
|
||||
else:
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
|
||||
|
||||
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
print(" > Text inputs:", text_inputs)
|
||||
if prompt_codes is not None:
|
||||
prompt_codes = self.get_prompts(prompt_codes)
|
||||
# prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
|
||||
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
|
||||
|
||||
print(" > Prompt inputs:", prompt_codes)
|
||||
print(" > Prompt inputs shape:", prompt_codes.shape)
|
||||
emb = torch.cat([prompt_emb, emb], dim=1)
|
||||
|
||||
if speech_conditioning_latent is not None:
|
||||
conds = speech_conditioning_latent.unsqueeze(1)
|
||||
emb = torch.cat([conds, emb], dim=1)
|
||||
|
||||
self.inference_model.store_prefix_emb(emb)
|
||||
|
||||
fake_inputs = torch.full(
|
||||
(
|
||||
emb.shape[0],
|
||||
emb.shape[1] + 1, # +1 for the start_mel_token
|
||||
),
|
||||
fill_value=1,
|
||||
dtype=torch.long,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
if input_tokens is not None:
|
||||
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
return fake_inputs
|
||||
|
||||
def inference(
|
||||
self,
|
||||
text_inputs,
|
||||
input_tokens=None,
|
||||
prompt_codes=None,
|
||||
pad_input_text=False,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
if pad_input_text and text_inputs.shape[1] < 250:
|
||||
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
|
||||
else:
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
|
||||
|
||||
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
if prompt_codes is not None:
|
||||
prompt_codes = self.get_prompts(prompt_codes)
|
||||
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
|
||||
emb = torch.cat([prompt_emb, emb], dim=1)
|
||||
|
||||
self.inference_model.store_prefix_emb(emb)
|
||||
|
||||
fake_inputs = torch.full(
|
||||
(
|
||||
emb.shape[0],
|
||||
emb.shape[1] + 1, # +1 for the start_mel_token
|
||||
),
|
||||
fill_value=1,
|
||||
dtype=torch.long,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
if input_tokens is not None:
|
||||
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
|
||||
gen = self.inference_model.generate(
|
||||
fake_inputs,
|
||||
bos_token_id=self.start_mel_token,
|
||||
pad_token_id=self.stop_mel_token,
|
||||
eos_token_id=self.stop_mel_token,
|
||||
max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||
return gen.sequences[:, fake_inputs.shape[1] :], gen
|
||||
return gen[:, fake_inputs.shape[1] :]
|
File diff suppressed because it is too large
@ -0,0 +1,742 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
import torchaudio
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def get_padding(k, d):
|
||||
return int((k * d - d) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|
||||
|--------------------------------------------------------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input tensor.
|
||||
Returns:
|
||||
Tensor: output tensor.
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
"""
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|
||||
|---------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super().__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class HifiganGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
resblock_type,
|
||||
resblock_dilation_sizes,
|
||||
resblock_kernel_sizes,
|
||||
upsample_kernel_sizes,
|
||||
upsample_initial_channel,
|
||||
upsample_factors,
|
||||
inference_padding=5,
|
||||
cond_channels=0,
|
||||
conv_pre_weight_norm=True,
|
||||
conv_post_weight_norm=True,
|
||||
conv_post_bias=True,
|
||||
cond_in_each_up_layer=False,
|
||||
):
|
||||
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
|
||||
|
||||
Network:
|
||||
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
|
||||
.. -> zI ---|
|
||||
resblockN_kNx1 -> zN ---'
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input tensor channels.
|
||||
out_channels (int): number of output tensor channels.
|
||||
resblock_type (str): type of the `ResBlock`. '1' or '2'.
|
||||
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
|
||||
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
|
||||
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
|
||||
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
|
||||
for each consecutive upsampling layer.
|
||||
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
|
||||
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inference_padding = inference_padding
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_factors)
|
||||
self.cond_in_each_up_layer = cond_in_each_up_layer
|
||||
|
||||
# initial upsampling layers
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
)
|
||||
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
|
||||
# upsampling layers
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
# MRF blocks
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||
):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
# post convolution layer
|
||||
self.conv_post = weight_norm(
|
||||
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
|
||||
)
|
||||
if cond_channels > 0:
|
||||
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
||||
|
||||
if not conv_pre_weight_norm:
|
||||
remove_weight_norm(self.conv_pre)
|
||||
|
||||
if not conv_post_weight_norm:
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
if self.cond_in_each_up_layer:
|
||||
self.conds = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
self.conds.append(nn.Conv1d(cond_channels, ch, 1))
|
||||
|
||||
def forward(self, x, g=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): feature input tensor.
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
o = self.conv_pre(x)
|
||||
if hasattr(self, "cond_layer"):
|
||||
o = o + self.cond_layer(g)
|
||||
for i in range(self.num_upsamples):
|
||||
o = F.leaky_relu(o, LRELU_SLOPE)
|
||||
o = self.ups[i](o)
|
||||
|
||||
if self.cond_in_each_up_layer:
|
||||
o = o + self.conds[i](g)
|
||||
|
||||
z_sum = None
|
||||
for j in range(self.num_kernels):
|
||||
if z_sum is None:
|
||||
z_sum = self.resblocks[i * self.num_kernels + j](o)
|
||||
else:
|
||||
z_sum += self.resblocks[i * self.num_kernels + j](o)
|
||||
o = z_sum / self.num_kernels
|
||||
o = F.leaky_relu(o)
|
||||
o = self.conv_post(o)
|
||||
o = torch.tanh(o)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""
|
||||
Args:
|
||||
c (Tensor): conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
c: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(
|
||||
c, (self.inference_padding, self.inference_padding), "replicate"
|
||||
)
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
|
||||
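A minimal usage sketch for the generator defined above (not part of this diff; the hyperparameter values are the common HiFi-GAN V1 settings and are only illustrative, the real values come from the vocoder/XTTS configs):

```python
import torch

gen = HifiganGenerator(
    in_channels=80,                                    # mel channels
    out_channels=1,                                    # mono waveform
    resblock_type="1",
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    resblock_kernel_sizes=[3, 7, 11],
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    upsample_factors=[8, 8, 2, 2],                     # total upsampling factor 256
)
mel = torch.randn(1, 80, 100)                          # [B, C, T]
wav = gen(mel)                                         # [B, 1, T * 256]
```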
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super(SELayer, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
y = self.avg_pool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y
|
||||
|
||||
|
||||
class SEBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
|
||||
super(SEBasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se = SELayer(planes, reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.relu(out)
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.se(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
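A quick sketch of what the squeeze-and-excitation layer above does (not part of this diff): it rescales each channel with a learned, input-dependent gate and leaves the spatial dimensions untouched.

```python
import torch

se = SELayer(channel=64, reduction=8)
feat = torch.randn(2, 64, 40, 50)   # [B, C, H, W]
out = se(feat)                      # same shape; each channel multiplied by a gate in (0, 1)
assert out.shape == feat.shape
```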
def set_init_dict(model_dict, checkpoint_state, c):
|
||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||
for k, v in checkpoint_state.items():
|
||||
if k not in model_dict:
|
||||
print(" | > Layer missing in the model definition: {}".format(k))
|
||||
# 1. filter out unnecessary keys
|
||||
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
|
||||
# 2. filter out different size layers
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
|
||||
# 3. skip reinit layers
|
||||
if c.has("reinit_layers") and c.reinit_layers is not None:
|
||||
for reinit_layer_name in c.reinit_layers:
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
|
||||
# 4. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrained_dict)
|
||||
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
|
||||
return model_dict
|
||||
|
||||
|
||||
class PreEmphasis(nn.Module):
|
||||
def __init__(self, coefficient=0.97):
|
||||
super().__init__()
|
||||
self.coefficient = coefficient
|
||||
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
def forward(self, x):
|
||||
assert len(x.size()) == 2
|
||||
|
||||
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
|
||||
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
|
||||
|
||||
|
||||
|
||||
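As a sanity check (a sketch, not part of this diff), the filter buffer registered above is equivalent to the textbook pre-emphasis y[t] = x[t] - 0.97 * x[t-1] with a reflected first sample:

```python
import torch

pre = PreEmphasis(coefficient=0.97)
x = torch.randn(2, 16000)                                        # [B, T] waveform batch
y = pre(x)                                                       # [B, T]

# explicit computation with the same reflect padding used by the module
x_pad = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect").squeeze(1)
y_ref = x_pad[:, 1:] - 0.97 * x_pad[:, :-1]
assert torch.allclose(y, y_ref, atol=1e-6)
```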
class ResNetSpeakerEncoder(nn.Module):
|
||||
"""This is copied from 🐸TTS to remove it from the dependencies.
|
||||
"""
|
||||
|
||||
# pylint: disable=W0102
|
||||
def __init__(
|
||||
self,
|
||||
input_dim=64,
|
||||
proj_dim=512,
|
||||
layers=[3, 4, 6, 3],
|
||||
num_filters=[32, 64, 128, 256],
|
||||
encoder_type="ASP",
|
||||
log_input=False,
|
||||
use_torch_spec=False,
|
||||
audio_config=None,
|
||||
):
|
||||
super(ResNetSpeakerEncoder, self).__init__()
|
||||
|
||||
self.encoder_type = encoder_type
|
||||
self.input_dim = input_dim
|
||||
self.log_input = log_input
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.audio_config = audio_config
|
||||
self.proj_dim = proj_dim
|
||||
|
||||
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.bn1 = nn.BatchNorm2d(num_filters[0])
|
||||
|
||||
self.inplanes = num_filters[0]
|
||||
self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
|
||||
self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
|
||||
self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
|
||||
self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
|
||||
|
||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||
|
||||
if self.use_torch_spec:
|
||||
self.torch_spec = torch.nn.Sequential(
|
||||
PreEmphasis(audio_config["preemphasis"]),
|
||||
torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=audio_config["sample_rate"],
|
||||
n_fft=audio_config["fft_size"],
|
||||
win_length=audio_config["win_length"],
|
||||
hop_length=audio_config["hop_length"],
|
||||
window_fn=torch.hamming_window,
|
||||
n_mels=audio_config["num_mels"],
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
self.torch_spec = None
|
||||
|
||||
outmap_size = int(self.input_dim / 8)
|
||||
|
||||
self.attention = nn.Sequential(
|
||||
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
|
||||
nn.Softmax(dim=2),
|
||||
)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
out_dim = num_filters[3] * outmap_size
|
||||
elif self.encoder_type == "ASP":
|
||||
out_dim = num_filters[3] * outmap_size * 2
|
||||
else:
|
||||
raise ValueError("Undefined encoder")
|
||||
|
||||
self.fc = nn.Linear(out_dim, proj_dim)
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def create_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
# pylint: disable=R0201
|
||||
def new_parameter(self, *size):
|
||||
out = nn.Parameter(torch.FloatTensor(*size))
|
||||
nn.init.xavier_normal_(out)
|
||||
return out
|
||||
|
||||
def forward(self, x, l2_norm=False):
|
||||
"""Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `use_torch_spec` must be `True`
|
||||
to compute the spectrogram on-the-fly.
|
||||
l2_norm (bool): Whether to L2-normalize the outputs.
|
||||
|
||||
Shapes:
|
||||
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
|
||||
"""
|
||||
x.squeeze_(1)
|
||||
# if use_torch_spec is enabled, compute the spectrogram here; otherwise use the mel spec computed by the AP
|
||||
if self.use_torch_spec:
|
||||
x = self.torch_spec(x)
|
||||
|
||||
if self.log_input:
|
||||
x = (x + 1e-6).log()
|
||||
x = self.instancenorm(x).unsqueeze(1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.bn1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = x.reshape(x.size()[0], -1, x.size()[-1])
|
||||
|
||||
w = self.attention(x)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
x = torch.sum(x * w, dim=2)
|
||||
elif self.encoder_type == "ASP":
|
||||
mu = torch.sum(x * w, dim=2)
|
||||
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
|
||||
x = torch.cat((mu, sg), 1)
|
||||
|
||||
x = x.view(x.size()[0], -1)
|
||||
x = self.fc(x)
|
||||
|
||||
if l2_norm:
|
||||
x = torch.nn.functional.normalize(x, p=2, dim=1)
|
||||
return x
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
checkpoint_path: str,
|
||||
eval: bool = False,
|
||||
use_cuda: bool = False,
|
||||
criterion=None,
|
||||
cache=False,
|
||||
):
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
try:
|
||||
self.load_state_dict(state["model"])
|
||||
print(" > Model fully restored. ")
|
||||
except (KeyError, RuntimeError) as error:
|
||||
# If eval raise the error
|
||||
if eval:
|
||||
raise error
|
||||
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = self.state_dict()
|
||||
model_dict = set_init_dict(model_dict, state["model"])
|
||||
self.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
# load the criterion for restore_path
|
||||
if criterion is not None and "criterion" in state:
|
||||
try:
|
||||
criterion.load_state_dict(state["criterion"])
|
||||
except (KeyError, RuntimeError) as error:
|
||||
print(" > Criterion load ignored because of:", error)
|
||||
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
if criterion is not None:
|
||||
criterion = criterion.cuda()
|
||||
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
if not eval:
|
||||
return criterion, state["step"]
|
||||
return criterion
|
||||
|
||||
class HifiDecoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_sample_rate=22050,
|
||||
output_sample_rate=24000,
|
||||
output_hop_length=256,
|
||||
ar_mel_length_compression=1024,
|
||||
decoder_input_dim=1024,
|
||||
resblock_type_decoder="1",
|
||||
resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
resblock_kernel_sizes_decoder=[3, 7, 11],
|
||||
upsample_rates_decoder=[8, 8, 2, 2],
|
||||
upsample_initial_channel_decoder=512,
|
||||
upsample_kernel_sizes_decoder=[16, 16, 4, 4],
|
||||
d_vector_dim=512,
|
||||
cond_d_vector_in_each_upsampling_layer=True,
|
||||
speaker_encoder_audio_config={
|
||||
"fft_size": 512,
|
||||
"win_length": 400,
|
||||
"hop_length": 160,
|
||||
"sample_rate": 16000,
|
||||
"preemphasis": 0.97,
|
||||
"num_mels": 64,
|
||||
},
|
||||
):
|
||||
super().__init__()
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_hop_length = output_hop_length
|
||||
self.ar_mel_length_compression = ar_mel_length_compression
|
||||
self.speaker_encoder_audio_config = speaker_encoder_audio_config
|
||||
self.waveform_decoder = HifiganGenerator(
|
||||
decoder_input_dim,
|
||||
1,
|
||||
resblock_type_decoder,
|
||||
resblock_dilation_sizes_decoder,
|
||||
resblock_kernel_sizes_decoder,
|
||||
upsample_kernel_sizes_decoder,
|
||||
upsample_initial_channel_decoder,
|
||||
upsample_rates_decoder,
|
||||
inference_padding=0,
|
||||
cond_channels=d_vector_dim,
|
||||
conv_pre_weight_norm=False,
|
||||
conv_post_weight_norm=False,
|
||||
conv_post_bias=False,
|
||||
cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
self.speaker_encoder = ResNetSpeakerEncoder(
|
||||
input_dim=64,
|
||||
proj_dim=512,
|
||||
log_input=True,
|
||||
use_torch_spec=True,
|
||||
audio_config=speaker_encoder_audio_config,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, latents, g=None):
|
||||
"""
|
||||
Args:
|
||||
latents (Tensor): feature input tensor (GPT latent).
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
latents: [B, T, C]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
|
||||
z = torch.nn.functional.interpolate(
|
||||
latents.transpose(1, 2),
|
||||
scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
|
||||
mode="linear",
|
||||
).squeeze(1)
|
||||
# upsample to the right sr
|
||||
if self.output_sample_rate != self.input_sample_rate:
|
||||
z = torch.nn.functional.interpolate(
|
||||
z,
|
||||
scale_factor=[self.output_sample_rate / self.input_sample_rate],
|
||||
mode="linear",
|
||||
).squeeze(0)
|
||||
o = self.waveform_decoder(z, g=g)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c, g):
|
||||
"""
|
||||
Args:
|
||||
c (Tensor): feature input tensor (GPT latent).
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
c: [B, T, C]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
return self.forward(c, g=g)
|
||||
|
||||
def load_checkpoint(
|
||||
self, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
# remove unused keys
|
||||
state = state["model"]
|
||||
states_keys = list(state.keys())
|
||||
for key in states_keys:
|
||||
if "waveform_decoder." not in key and "speaker_encoder." not in key:
|
||||
del state[key]
|
||||
|
||||
self.load_state_dict(state)
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.waveform_decoder.remove_weight_norm()
|
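A rough end-to-end sketch of the decoder defined above (not part of this diff; shapes follow the docstrings, the reference waveform is a random stand-in for a real 16 kHz clip, and the GPT latents are random):

```python
import torch

decoder = HifiDecoder()                                   # constructor defaults from above

# speaker d-vector from the bundled ResNet encoder (expects 16 kHz audio by default)
wav_ref = torch.randn(1, 16000 * 3)                       # stand-in for a 3 s reference clip
g = decoder.speaker_encoder(wav_ref, l2_norm=True).unsqueeze(-1)   # [1, 512, 1]

latents = torch.randn(1, 50, 1024)                        # [B, T, decoder_input_dim] GPT latents
audio = decoder(latents, g=g)                             # [1, 1, T_wav] waveform at 24 kHz
```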
File diff suppressed because it is too large
|
@ -171,17 +171,6 @@ def multilingual_cleaners(text, lang):
|
|||
return text
|
||||
|
||||
|
||||
def english_cleaners(text):
|
||||
"""Pipeline for English text, including number and abbreviation expansion."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', "")
|
||||
return text
|
||||
|
||||
|
||||
def remove_extraneous_punctuation(word):
|
||||
replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "—": "-", "—": "-", "`": "'", "ʼ": "'"}
|
||||
replace = re.compile(
|
||||
|
@ -195,32 +184,6 @@ def remove_extraneous_punctuation(word):
|
|||
return word
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, " ", text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def arabic_cleaners(text):
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
|
@ -261,7 +224,10 @@ class VoiceBpeTokenizer:
|
|||
txt = " ".join([result["kana"] for result in results])
|
||||
txt = basic_cleaners(txt)
|
||||
elif lang == "en":
|
||||
if txt[:4] == "[en]":
|
||||
txt = txt[4:]
|
||||
txt = english_cleaners(txt)
|
||||
txt = "[en]" + txt
|
||||
elif lang == "ar":
|
||||
txt = arabic_cleaners(txt)
|
||||
elif lang == "zh-cn":
|
||||
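A small sketch of the English branch above (not part of this diff; the exact output depends on the abbreviation and number-expansion helpers in this module):

```python
text = ' He said: "It is   3 PM." '
out = english_cleaners(text)
# lowercased, ASCII-folded, digits spelled out, double quotes dropped, whitespace collapsed,
# giving roughly: 'he said: it is three pm. '
```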
|
|
|
@ -726,8 +726,8 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
def pitch_std(self):
|
||||
return self.acoustic_model.pitch_std
|
||||
|
||||
@pitch_mean.setter
|
||||
def pitch_std(self, value): # pylint: disable=function-redefined
|
||||
@pitch_std.setter
|
||||
def pitch_std(self, value):
|
||||
self.acoustic_model.pitch_std = value
|
||||
|
||||
@property
|
||||
|
@ -1518,10 +1518,6 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
|
||||
return [scheduler_D, scheduler_G]
|
||||
|
||||
def on_train_step_start(self, trainer):
|
||||
"""Schedule binary loss weight."""
|
||||
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
||||
|
||||
def on_epoch_end(self, trainer): # pylint: disable=unused-argument
|
||||
# stop updating mean and var
|
||||
# TODO: do the same for F0
|
||||
|
@ -1578,6 +1574,7 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
Args:
|
||||
trainer (Trainer): Trainer object.
|
||||
"""
|
||||
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
||||
self.train_disc = ( # pylint: disable=attribute-defined-outside-init
|
||||
trainer.total_steps_done >= self.config.steps_to_start_discriminator
|
||||
)
|
||||
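For clarity (a sketch, not part of this diff): the expression above is a linear warmup, so with `binary_loss_warmup_epochs = 10` the weight grows from 0 to 1 and then stays at 1.

```python
for epochs_done in (0, 5, 10, 20):
    print(min(epochs_done / 10, 1.0) * 1.0)   # 0.0, 0.5, 1.0, 1.0
```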
|
|
|
@ -13,9 +13,12 @@ from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedu
|
|||
from TTS.tts.layers.xtts.gpt import GPT
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
|
||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
init_stream_support()
|
||||
|
||||
def load_audio(audiopath, sr=22050):
|
||||
"""
|
||||
|
@ -195,13 +198,12 @@ class XttsArgs(Coqpit):
|
|||
Args:
|
||||
gpt_batch_size (int): The size of the auto-regressive batch.
|
||||
enable_redaction (bool, optional): Whether to enable redaction. Defaults to False.
|
||||
lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to True.
|
||||
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
|
||||
gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
|
||||
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
|
||||
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
|
||||
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
|
||||
vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
|
||||
use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.
|
||||
|
||||
For GPT model:
|
||||
ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
|
||||
|
@ -231,12 +233,12 @@ class XttsArgs(Coqpit):
|
|||
|
||||
gpt_batch_size: int = 1
|
||||
enable_redaction: bool = False
|
||||
lazy_load: bool = True
|
||||
kv_cache: bool = True
|
||||
gpt_checkpoint: str = None
|
||||
clvp_checkpoint: str = None
|
||||
decoder_checkpoint: str = None
|
||||
num_chars: int = 255
|
||||
use_hifigan: bool = True
|
||||
|
||||
# XTTS GPT Encoder params
|
||||
tokenizer_file: str = ""
|
||||
|
@ -266,6 +268,15 @@ class XttsArgs(Coqpit):
|
|||
diff_layer_drop: int = 0
|
||||
diff_unconditioned_percentage: int = 0
|
||||
|
||||
# HifiGAN Decoder params
|
||||
input_sample_rate: int = 22050
|
||||
output_sample_rate: int = 24000
|
||||
output_hop_length: int = 256
|
||||
ar_mel_length_compression: int = 1024
|
||||
decoder_input_dim: int = 1024
|
||||
d_vector_dim: int = 512
|
||||
cond_d_vector_in_each_upsampling_layer: bool = True
|
||||
|
||||
# constants
|
||||
duration_const: int = 102400
|
||||
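The new HifiGAN-decoder fields above are plain Coqpit dataclass defaults, so they can be overridden as usual; a sketch with illustrative values (not part of this diff):

```python
model_args = XttsArgs(
    use_hifigan=True,            # HiFiGAN decoder instead of diffusion + UnivNet
    output_sample_rate=24000,
    d_vector_dim=512,
)
```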
|
||||
|
@ -285,7 +296,6 @@ class Xtts(BaseTTS):
|
|||
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config, ap=None, tokenizer=None)
|
||||
self.lazy_load = self.args.lazy_load
|
||||
self.mel_stats_path = None
|
||||
self.config = config
|
||||
self.gpt_checkpoint = self.args.gpt_checkpoint
|
||||
|
@ -295,7 +305,6 @@ class Xtts(BaseTTS):
|
|||
|
||||
self.tokenizer = VoiceBpeTokenizer()
|
||||
self.gpt = None
|
||||
self.diffusion_decoder = None
|
||||
self.init_models()
|
||||
self.register_buffer("mel_stats", torch.ones(80))
|
||||
|
||||
|
@ -322,40 +331,39 @@ class Xtts(BaseTTS):
|
|||
stop_audio_token=self.args.gpt_stop_audio_token,
|
||||
)
|
||||
|
||||
self.diffusion_decoder = DiffusionTts(
|
||||
model_channels=self.args.diff_model_channels,
|
||||
num_layers=self.args.diff_num_layers,
|
||||
in_channels=self.args.diff_in_channels,
|
||||
out_channels=self.args.diff_out_channels,
|
||||
in_latent_channels=self.args.diff_in_latent_channels,
|
||||
in_tokens=self.args.diff_in_tokens,
|
||||
dropout=self.args.diff_dropout,
|
||||
use_fp16=self.args.diff_use_fp16,
|
||||
num_heads=self.args.diff_num_heads,
|
||||
layer_drop=self.args.diff_layer_drop,
|
||||
unconditioned_percentage=self.args.diff_unconditioned_percentage,
|
||||
)
|
||||
|
||||
self.vocoder = UnivNetGenerator()
|
||||
if self.args.use_hifigan:
|
||||
self.hifigan_decoder = HifiDecoder(
|
||||
input_sample_rate=self.args.input_sample_rate,
|
||||
output_sample_rate=self.args.output_sample_rate,
|
||||
output_hop_length=self.args.output_hop_length,
|
||||
ar_mel_length_compression=self.args.ar_mel_length_compression,
|
||||
decoder_input_dim=self.args.decoder_input_dim,
|
||||
d_vector_dim=self.args.d_vector_dim,
|
||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
|
||||
else:
|
||||
self.diffusion_decoder = DiffusionTts(
|
||||
model_channels=self.args.diff_model_channels,
|
||||
num_layers=self.args.diff_num_layers,
|
||||
in_channels=self.args.diff_in_channels,
|
||||
out_channels=self.args.diff_out_channels,
|
||||
in_latent_channels=self.args.diff_in_latent_channels,
|
||||
in_tokens=self.args.diff_in_tokens,
|
||||
dropout=self.args.diff_dropout,
|
||||
use_fp16=self.args.diff_use_fp16,
|
||||
num_heads=self.args.diff_num_heads,
|
||||
layer_drop=self.args.diff_layer_drop,
|
||||
unconditioned_percentage=self.args.diff_unconditioned_percentage,
|
||||
)
|
||||
self.vocoder = UnivNetGenerator()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@contextmanager
|
||||
def lazy_load_model(self, model):
|
||||
"""Context to load a model on demand.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be loaded.
|
||||
"""
|
||||
if self.lazy_load:
|
||||
yield model
|
||||
else:
|
||||
m = model.to(self.device)
|
||||
yield m
|
||||
m = model.cpu()
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
|
||||
"""Compute the conditioning latents for the GPT model from the given audio.
|
||||
|
||||
|
@ -370,6 +378,7 @@ class Xtts(BaseTTS):
|
|||
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
|
||||
return cond_latent.transpose(1, 2)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_diffusion_cond_latents(
|
||||
self,
|
||||
audio_path,
|
||||
|
@ -389,20 +398,33 @@ class Xtts(BaseTTS):
|
|||
)
|
||||
diffusion_conds.append(cond_mel)
|
||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||
with self.lazy_load_model(self.diffusion_decoder) as diffusion:
|
||||
diffusion_latent = diffusion.get_conditioning(diffusion_conds)
|
||||
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
|
||||
return diffusion_latent
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_speaker_embedding(
|
||||
self,
|
||||
audio_path
|
||||
):
|
||||
audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
|
||||
speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
|
||||
audio.to(self.device), l2_norm=True
|
||||
).unsqueeze(-1).to(self.device)
|
||||
return speaker_embedding
|
||||
|
||||
def get_conditioning_latents(
|
||||
self,
|
||||
audio_path,
|
||||
gpt_cond_len=3,
|
||||
):
|
||||
):
|
||||
speaker_embedding = None
|
||||
diffusion_cond_latents = None
|
||||
if self.args.use_hifigan:
|
||||
speaker_embedding = self.get_speaker_embedding(audio_path)
|
||||
else:
|
||||
diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
|
||||
gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T]
|
||||
diffusion_cond_latents = self.get_diffusion_cond_latents(
|
||||
audio_path,
|
||||
)
|
||||
return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
|
||||
return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
|
||||
|
||||
def synthesize(self, text, config, speaker_wav, language, **kwargs):
|
||||
"""Synthesize speech with the given input text.
|
||||
|
@ -447,10 +469,10 @@ class Xtts(BaseTTS):
|
|||
"decoder_sampler": config.decoder_sampler,
|
||||
}
|
||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
return self.inference(text, ref_audio_path, language, **settings)
|
||||
return self.full_inference(text, ref_audio_path, language, **settings)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
@torch.inference_mode()
|
||||
def full_inference(
|
||||
self,
|
||||
text,
|
||||
ref_audio_path,
|
||||
|
@ -525,6 +547,54 @@ class Xtts(BaseTTS):
|
|||
Generated audio clip(s) as a torch tensor. Shape is (1, S) if k=1, else (k, 1, S), where S is the sample length.
|
||||
Sample rate is 24kHz.
|
||||
"""
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
speaker_embedding
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
|
||||
return self.inference(
|
||||
text,
|
||||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
temperature=temperature,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=do_sample,
|
||||
decoder_iterations=decoder_iterations,
|
||||
cond_free=cond_free,
|
||||
cond_free_k=cond_free_k,
|
||||
diffusion_temperature=diffusion_temperature,
|
||||
decoder_sampler=decoder_sampler,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text,
|
||||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
# GPT inference
|
||||
temperature=0.65,
|
||||
length_penalty=1,
|
||||
repetition_penalty=2.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
decoder_iterations=100,
|
||||
cond_free=True,
|
||||
cond_free_k=2,
|
||||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
text = f"[{language}]{text.strip().lower()}"
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
|
@ -532,74 +602,147 @@ class Xtts(BaseTTS):
|
|||
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
|
||||
|
||||
diffuser = load_discrete_vocoder_diffuser(
|
||||
desired_diffusion_steps=decoder_iterations,
|
||||
cond_free=cond_free,
|
||||
cond_free_k=cond_free_k,
|
||||
sampler=decoder_sampler,
|
||||
)
|
||||
if not self.args.use_hifigan:
|
||||
diffuser = load_discrete_vocoder_diffuser(
|
||||
desired_diffusion_steps=decoder_iterations,
|
||||
cond_free=cond_free,
|
||||
cond_free_k=cond_free_k,
|
||||
sampler=decoder_sampler,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
self.gpt = self.gpt.to(self.device)
|
||||
with self.lazy_load_model(self.gpt) as gpt:
|
||||
gpt_codes = gpt.generate(
|
||||
cond_latents=gpt_cond_latent,
|
||||
text_inputs=text_tokens,
|
||||
input_tokens=None,
|
||||
do_sample=do_sample,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.gpt_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
output_attentions=False,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
with self.lazy_load_model(self.gpt) as gpt:
|
||||
expected_output_len = torch.tensor(
|
||||
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
|
||||
)
|
||||
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
|
||||
gpt_latents = gpt(
|
||||
text_tokens,
|
||||
text_len,
|
||||
gpt_codes,
|
||||
expected_output_len,
|
||||
cond_latents=gpt_cond_latent,
|
||||
return_attentions=False,
|
||||
return_latent=True,
|
||||
)
|
||||
silence_token = 83
|
||||
ctokens = 0
|
||||
for k in range(gpt_codes.shape[-1]):
|
||||
if gpt_codes[0, k] == silence_token:
|
||||
ctokens += 1
|
||||
else:
|
||||
ctokens = 0
|
||||
if ctokens > 8:
|
||||
gpt_latents = gpt_latents[:, :k]
|
||||
break
|
||||
|
||||
with self.lazy_load_model(self.diffusion_decoder) as diffusion:
|
||||
gpt_codes = self.gpt.generate(
|
||||
cond_latents=gpt_cond_latent,
|
||||
text_inputs=text_tokens,
|
||||
input_tokens=None,
|
||||
do_sample=do_sample,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.gpt_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
output_attentions=False,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
expected_output_len = torch.tensor(
|
||||
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
|
||||
)
|
||||
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
|
||||
gpt_latents = self.gpt(
|
||||
text_tokens,
|
||||
text_len,
|
||||
gpt_codes,
|
||||
expected_output_len,
|
||||
cond_latents=gpt_cond_latent,
|
||||
return_attentions=False,
|
||||
return_latent=True,
|
||||
)
|
||||
silence_token = 83
|
||||
ctokens = 0
|
||||
for k in range(gpt_codes.shape[-1]):
|
||||
if gpt_codes[0, k] == silence_token:
|
||||
ctokens += 1
|
||||
else:
|
||||
ctokens = 0
|
||||
if ctokens > 8:
|
||||
gpt_latents = gpt_latents[:, :k]
|
||||
break
|
||||
|
||||
if self.args.use_hifigan:
|
||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
else:
|
||||
mel = do_spectrogram_diffusion(
|
||||
diffusion,
|
||||
self.diffusion_decoder,
|
||||
diffuser,
|
||||
gpt_latents,
|
||||
diffusion_conditioning,
|
||||
temperature=diffusion_temperature,
|
||||
)
|
||||
with self.lazy_load_model(self.vocoder) as vocoder:
|
||||
wav = vocoder.inference(mel)
|
||||
wav = self.vocoder.inference(mel)
|
||||
|
||||
return {"wav": wav.cpu().numpy().squeeze()}
|
||||
|
||||
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
|
||||
"""Handle chunk formatting in streaming mode"""
|
||||
wav_chunk = wav_gen[:-overlap_len]
|
||||
if wav_gen_prev is not None:
|
||||
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
|
||||
if wav_overlap is not None:
|
||||
crossfade_wav = wav_chunk[:overlap_len]
|
||||
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
|
||||
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
|
||||
wav_chunk[:overlap_len] += crossfade_wav
|
||||
wav_overlap = wav_gen[-overlap_len:]
|
||||
wav_gen_prev = wav_gen
|
||||
return wav_chunk, wav_gen_prev, wav_overlap
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_stream(
|
||||
self,
|
||||
text,
|
||||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
# Streaming
|
||||
stream_chunk_size=20,
|
||||
overlap_wav_len=1024,
|
||||
# GPT inference
|
||||
temperature=0.65,
|
||||
length_penalty=1,
|
||||
repetition_penalty=2.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
|
||||
text = f"[{language}]{text.strip().lower()}"
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
fake_inputs = self.gpt.compute_embeddings(
|
||||
gpt_cond_latent.to(self.device),
|
||||
text_tokens,
|
||||
)
|
||||
gpt_generator = self.gpt.get_generator(
|
||||
fake_inputs=fake_inputs,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
do_sample=do_sample,
|
||||
num_beams=1,
|
||||
num_return_sequences=1,
|
||||
length_penalty=float(length_penalty),
|
||||
repetition_penalty=float(repetition_penalty),
|
||||
output_attentions=False,
|
||||
output_hidden_states=True,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
last_tokens = []
|
||||
all_latents = []
|
||||
wav_gen_prev = None
|
||||
wav_overlap = None
|
||||
is_end = False
|
||||
|
||||
while not is_end:
|
||||
try:
|
||||
x, latent = next(gpt_generator)
|
||||
last_tokens += [x]
|
||||
all_latents += [latent]
|
||||
except StopIteration:
|
||||
is_end = True
|
||||
|
||||
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
|
||||
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
||||
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|
||||
)
|
||||
last_tokens = []
|
||||
yield wav_chunk
|
||||
|
||||
def forward(self):
|
||||
raise NotImplementedError("XTTS Training is not implemented")
|
||||
|
||||
|
@ -616,7 +759,14 @@ class Xtts(BaseTTS):
|
|||
super().eval()
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir=None,
|
||||
checkpoint_path=None,
|
||||
vocab_path=None,
|
||||
eval=True,
|
||||
strict=True,
|
||||
use_deepspeed=False,
|
||||
):
|
||||
"""
|
||||
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
||||
|
@ -626,7 +776,7 @@ class Xtts(BaseTTS):
|
|||
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
|
||||
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
|
||||
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
|
||||
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False.
|
||||
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
|
||||
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
|
||||
|
||||
Returns:
|
||||
|
@ -636,19 +786,26 @@ class Xtts(BaseTTS):
|
|||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||
|
||||
if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")):
|
||||
self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json"))
|
||||
if os.path.exists(vocab_path):
|
||||
self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
|
||||
|
||||
self.init_models()
|
||||
if eval:
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
|
||||
self.load_state_dict(load_fsspec(model_path)["model"], strict=strict)
|
||||
|
||||
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
|
||||
for key in list(checkpoint.keys()):
|
||||
if key.split(".")[0] in ignore_keys:
|
||||
del checkpoint[key]
|
||||
self.load_state_dict(checkpoint, strict=strict)
|
||||
|
||||
if eval:
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
|
||||
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
|
||||
if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
|
||||
if hasattr(self, "vocoder"): self.vocoder.eval()
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
|
||||
self.gpt.eval()
|
||||
self.diffusion_decoder.eval()
|
||||
self.vocoder.eval()
|
||||
|
||||
def train_step(self):
|
||||
raise NotImplementedError("XTTS Training is not implemented")
|
||||
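The net effect of the refactor above is that conditioning can be computed once and reused; a sketch assuming `model` is an already loaded `Xtts` instance (paths and sentences are placeholders):

```python
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(
    audio_path="reference.wav", gpt_cond_len=3
)
for sentence in ["First sentence.", "And a second one."]:
    out = model.inference(
        sentence, "en", gpt_cond_latent, speaker_embedding, diffusion_conditioning
    )
    # out["wav"] is a 24 kHz numpy array
```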
|
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
from shutil import copyfile, rmtree
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import fsspec
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -315,11 +316,51 @@ class ModelManager(object):
|
|||
"""Check if the user has agreed to the terms of service"""
|
||||
if "tos_required" in model_item and model_item["tos_required"]:
|
||||
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
||||
if os.path.exists(tos_path):
|
||||
if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1":
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def create_dir_and_download_model(self, model_name, model_item, output_path):
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
# handle TOS
|
||||
if not self.tos_agreed(model_item, output_path):
|
||||
if not self.ask_tos(output_path):
|
||||
os.rmdir(output_path)
|
||||
raise Exception(" [!] You must agree to the terms of service to use this model.")
|
||||
print(f" > Downloading model to {output_path}")
|
||||
try:
|
||||
if "fairseq" in model_name:
|
||||
self.download_fairseq_model(model_name, output_path)
|
||||
elif "github_rls_url" in model_item:
|
||||
self._download_github_model(model_item, output_path)
|
||||
elif "hf_url" in model_item:
|
||||
self._download_hf_model(model_item, output_path)
|
||||
|
||||
except requests.RequestException as e:
|
||||
print(f" > Failed to download the model file to {output_path}")
|
||||
rmtree(output_path)
|
||||
raise e
|
||||
self.print_model_license(model_item=model_item)
|
||||
|
||||
def check_if_configs_are_equal(self, model_name, model_item, output_path):
|
||||
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
|
||||
config_local = json.load(f)
|
||||
remote_url = None
|
||||
for url in model_item["hf_url"]:
|
||||
if "config.json" in url:
|
||||
remote_url = url
|
||||
break
|
||||
|
||||
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
|
||||
config_remote = json.load(f)
|
||||
|
||||
if not config_local == config_remote:
|
||||
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
else:
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
|
||||
def download_model(self, model_name):
|
||||
"""Download model files given the full model name.
|
||||
Model name is in the format
|
||||
|
@ -338,28 +379,18 @@ class ModelManager(object):
|
|||
# set the model specific output path
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if os.path.exists(output_path):
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
# if the configs are different, redownload it
|
||||
# ToDo: we need a better way to handle it
|
||||
if "xtts_v1" in model_name:
|
||||
try:
|
||||
self.check_if_configs_are_equal(model_name, model_item, output_path)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
else:
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
# handle TOS
|
||||
if not self.tos_agreed(model_item, output_path):
|
||||
if not self.ask_tos(output_path):
|
||||
os.rmdir(output_path)
|
||||
raise Exception(" [!] You must agree to the terms of service to use this model.")
|
||||
print(f" > Downloading model to {output_path}")
|
||||
try:
|
||||
if "fairseq" in model_name:
|
||||
self.download_fairseq_model(model_name, output_path)
|
||||
elif "github_rls_url" in model_item:
|
||||
self._download_github_model(model_item, output_path)
|
||||
elif "hf_url" in model_item:
|
||||
self._download_hf_model(model_item, output_path)
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
|
||||
except requests.RequestException as e:
|
||||
print(f" > Failed to download the model file to {output_path}")
|
||||
rmtree(output_path)
|
||||
raise e
|
||||
self.print_model_license(model_item=model_item)
|
||||
# find downloaded files
|
||||
output_model_path = output_path
|
||||
output_config_path = None
|
||||
|
@ -498,13 +529,16 @@ class ModelManager(object):
|
|||
print(f" > Error: Bad zip file - {file_url}")
|
||||
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
|
||||
# move the files to the outer path
|
||||
for file_path in z.namelist()[1:]:
|
||||
for file_path in z.namelist():
|
||||
src_path = os.path.join(output_folder, file_path)
|
||||
dst_path = os.path.join(output_folder, os.path.basename(file_path))
|
||||
if src_path != dst_path:
|
||||
copyfile(src_path, dst_path)
|
||||
# remove the extracted folder
|
||||
rmtree(os.path.join(output_folder, z.namelist()[0]))
|
||||
if os.path.isfile(src_path):
|
||||
dst_path = os.path.join(output_folder, os.path.basename(file_path))
|
||||
if src_path != dst_path:
|
||||
copyfile(src_path, dst_path)
|
||||
# remove redundant (hidden or not) folders
|
||||
for file_path in z.namelist():
|
||||
if os.path.isdir(os.path.join(output_folder, file_path)):
|
||||
rmtree(os.path.join(output_folder, file_path))
|
||||
|
||||
@staticmethod
|
||||
def _download_tar_file(file_url, output_folder, progress_bar):
|
||||
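With the `COQUI_TOS_AGREED` check added above, the CPML prompt can be pre-accepted in non-interactive environments (CI, Docker); a sketch (the model name is illustrative):

```python
import os

os.environ["COQUI_TOS_AGREED"] = "1"   # accept the terms of service without the interactive prompt

from TTS.utils.manage import ModelManager

manager = ModelManager()
model_path, config_path, model_item = manager.download_model(
    "tts_models/multilingual/multi-dataset/xtts_v1"
)
```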
|
|
|
@ -116,12 +116,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
|||
return acts
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def shift_1d(x):
|
||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||
return x
|
||||
|
|
|
@ -28,7 +28,8 @@ This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml)
|
|||
Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
|
||||
You can also mail us at info@coqui.ai.
|
||||
|
||||
Using 🐸TTS API:
|
||||
### Inference
|
||||
#### 🐸TTS API
|
||||
|
||||
```python
|
||||
from TTS.api import TTS
|
||||
|
@ -39,16 +40,9 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
|
|||
file_path="output.wav",
|
||||
speaker_wav="/path/to/target/speaker.wav",
|
||||
language="en")
|
||||
|
||||
# generate speech by cloning a voice using custom settings
|
||||
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
file_path="output.wav",
|
||||
speaker_wav="/path/to/target/speaker.wav",
|
||||
language="en",
|
||||
decoder_iterations=30)
|
||||
```
|
||||
|
||||
Using 🐸TTS Command line:
|
||||
#### 🐸TTS Command line
|
||||
|
||||
```console
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
|
||||
|
@ -58,25 +52,85 @@ Using 🐸TTS Command line:
|
|||
--use_cuda true
|
||||
```
|
||||
|
||||
Using model directly:
|
||||
#### model directly
|
||||
|
||||
To run with `use_deepspeed=True` and benefit from the speedup, install DeepSpeed first.
|
||||
|
||||
```console
|
||||
pip install deepspeed==0.8.3
|
||||
```
|
||||
|
||||
```python
|
||||
import os
|
||||
import torch
|
||||
import torchaudio
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
print("Loading model...")
|
||||
config = XttsConfig()
|
||||
config.load_json("/path/to/xtts/config.json")
|
||||
model = Xtts.init_from_config(config)
|
||||
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
|
||||
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
|
||||
model.cuda()
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
||||
|
||||
print("Inference...")
|
||||
out = model.inference(
|
||||
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
|
||||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
temperature=0.7, # Add custom parameters here
|
||||
)
|
||||
torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
||||
```
|
||||
|
||||
|
||||
#### streaming inference
|
||||
|
||||
Here the goal is to stream the audio as it is being generated. This is useful for real-time applications.
|
||||
Streaming inference is typically slower than regular inference overall, but it returns the first chunk of audio much sooner.
|
||||
|
||||
|
||||
```python
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torchaudio
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
print("Loading model...")
|
||||
config = XttsConfig()
|
||||
config.load_json("/path/to/xtts/config.json")
|
||||
model = Xtts.init_from_config(config)
|
||||
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
|
||||
model.cuda()
|
||||
|
||||
outputs = model.synthesize(
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
||||
|
||||
print("Inference...")
|
||||
t0 = time.time()
|
||||
chunks = model.inference_stream(
|
||||
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
|
||||
config,
|
||||
speaker_wav="/data/TTS-public/_refclips/3.wav",
|
||||
gpt_cond_len=3,
|
||||
language="en",
|
||||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding
|
||||
)
|
||||
|
||||
wav_chunks = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i == 0:
|
||||
print(f"Time to first chunck: {time.time() - t0}")
|
||||
print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
|
||||
wav_chunks.append(chunk)
|
||||
wav = torch.cat(wav_chunks, dim=0)
|
||||
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -13,15 +13,15 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import torch\n",
|
||||
"import importlib\n",
|
||||
"import numpy as np\n",
|
||||
"from tqdm import tqdm as tqdm\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import soundfile as sf\n",
|
||||
"import pickle\n",
|
||||
"from TTS.tts.datasets.dataset import TTSDataset\n",
|
||||
"from TTS.tts.layers.losses import L1LossMasked\n",
|
||||
"from TTS.utils.audio import AudioProcessor\n",
|
||||
|
@ -33,8 +33,8 @@
|
|||
"\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
|
||||
"# Configure CUDA visibility\n",
|
||||
"os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -43,6 +43,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Function to create directories and file names\n",
|
||||
"def set_filename(wav_path, out_path):\n",
|
||||
" wav_file = os.path.basename(wav_path)\n",
|
||||
" file_name = wav_file.split('.')[0]\n",
|
||||
|
@ -61,6 +62,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Paths and configurations\n",
|
||||
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
|
||||
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
|
||||
"DATASET = \"ljspeech\"\n",
|
||||
|
@ -73,12 +75,15 @@
|
|||
"QUANTIZE_BIT = None\n",
|
||||
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
||||
"\n",
|
||||
"# Check CUDA availability\n",
|
||||
"use_cuda = torch.cuda.is_available()\n",
|
||||
"print(\" > CUDA enabled: \", use_cuda)\n",
|
||||
"\n",
|
||||
"# Load the configuration\n",
|
||||
"C = load_config(CONFIG_PATH)\n",
|
||||
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
|
||||
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
|
||||
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
|
||||
"print(C['r'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -87,14 +92,13 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(C['r'])\n",
|
||||
"# if the vocabulary was passed, replace the default\n",
|
||||
"# If the vocabulary was passed, replace the default\n",
|
||||
"if 'characters' in C and C['characters']:\n",
|
||||
" symbols, phonemes = make_symbols(**C.characters)\n",
|
||||
"\n",
|
||||
"# load the model\n",
|
||||
"# Load the model\n",
|
||||
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
|
||||
"# TODO: multiple speaker\n",
|
||||
"# TODO: multiple speakers\n",
|
||||
"model = setup_model(C)\n",
|
||||
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
||||
]
|
||||
|
@ -105,11 +109,12 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the preprocessor based on the dataset\n",
|
||||
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
|
||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
||||
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
|
||||
"dataset = TTSDataset(\n",
|
||||
" checkpoint[\"config\"][\"r\"],\n",
|
||||
" C,\n",
|
||||
" C.text_cleaner,\n",
|
||||
" False,\n",
|
||||
" ap,\n",
|
||||
|
@ -124,6 +129,24 @@
|
|||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Initialize lists for storing results\n",
|
||||
"file_idxs = []\n",
|
||||
"metadata = []\n",
|
||||
"losses = []\n",
|
||||
"postnet_losses = []\n",
|
||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||
"\n",
|
||||
"# Create log file\n",
|
||||
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
||||
"log_file = open(log_file_path, \"w\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -137,83 +160,85 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"file_idxs = []\n",
|
||||
"metadata = []\n",
|
||||
"losses = []\n",
|
||||
"postnet_losses = []\n",
|
||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||
"# Start processing with a progress bar\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for data in tqdm(loader):\n",
|
||||
" # setup input data\n",
|
||||
" text_input = data[0]\n",
|
||||
" text_lengths = data[1]\n",
|
||||
" linear_input = data[3]\n",
|
||||
" mel_input = data[4]\n",
|
||||
" mel_lengths = data[5]\n",
|
||||
" stop_targets = data[6]\n",
|
||||
" item_idx = data[7]\n",
|
||||
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
||||
" try:\n",
|
||||
" # setup input data\n",
|
||||
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
|
||||
"\n",
|
||||
" # dispatch data to GPU\n",
|
||||
" if use_cuda:\n",
|
||||
" text_input = text_input.cuda()\n",
|
||||
" text_lengths = text_lengths.cuda()\n",
|
||||
" mel_input = mel_input.cuda()\n",
|
||||
" mel_lengths = mel_lengths.cuda()\n",
|
||||
" # dispatch data to GPU\n",
|
||||
" if use_cuda:\n",
|
||||
" text_input = text_input.cuda()\n",
|
||||
" text_lengths = text_lengths.cuda()\n",
|
||||
" mel_input = mel_input.cuda()\n",
|
||||
" mel_lengths = mel_lengths.cuda()\n",
|
||||
"\n",
|
||||
" mask = sequence_mask(text_lengths)\n",
|
||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
||||
" \n",
|
||||
" # compute loss\n",
|
||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
" postnet_losses.append(loss_postnet.item())\n",
|
||||
" mask = sequence_mask(text_lengths)\n",
|
||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
||||
"\n",
|
||||
" # compute mel specs from linear spec if model is Tacotron\n",
|
||||
" if C.model == \"Tacotron\":\n",
|
||||
" mel_specs = []\n",
|
||||
" postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
|
||||
" for b in range(postnet_outputs.shape[0]):\n",
|
||||
" postnet_output = postnet_outputs[b]\n",
|
||||
" mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
|
||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||
" elif C.model == \"Tacotron2\":\n",
|
||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||
" alignments = alignments.detach().cpu().numpy()\n",
|
||||
" # compute loss\n",
|
||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
" postnet_losses.append(loss_postnet.item())\n",
|
||||
"\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" for idx in range(text_input.shape[0]):\n",
|
||||
" wav_file_path = item_idx[idx]\n",
|
||||
" wav = ap.load_wav(wav_file_path)\n",
|
||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||
" file_idxs.append(file_name)\n",
|
||||
" # compute mel specs from linear spec if the model is Tacotron\n",
|
||||
" if C.model == \"Tacotron\":\n",
|
||||
" mel_specs = []\n",
|
||||
" postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
|
||||
" for b in range(postnet_outputs.shape[0]):\n",
|
||||
" postnet_output = postnet_outputs[b]\n",
|
||||
" mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
|
||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||
" elif C.model == \"Tacotron2\":\n",
|
||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||
" alignments = alignments.detach().cpu().numpy()\n",
|
||||
"\n",
|
||||
" # quantize and save wav\n",
|
||||
" if QUANTIZED_WAV:\n",
|
||||
" wavq = ap.quantize(wav)\n",
|
||||
" np.save(wavq_path, wavq)\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" for idx in range(text_input.shape[0]):\n",
|
||||
" wav_file_path = item_idx[idx]\n",
|
||||
" wav = ap.load_wav(wav_file_path)\n",
|
||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||
" file_idxs.append(file_name)\n",
|
||||
"\n",
|
||||
" # save TTS mel\n",
|
||||
" mel = postnet_outputs[idx]\n",
|
||||
" mel_length = mel_lengths[idx]\n",
|
||||
" mel = mel[:mel_length, :].T\n",
|
||||
" np.save(mel_path, mel)\n",
|
||||
" # quantize and save wav\n",
|
||||
" if QUANTIZED_WAV:\n",
|
||||
" wavq = ap.quantize(wav)\n",
|
||||
" np.save(wavq_path, wavq)\n",
|
||||
"\n",
|
||||
" metadata.append([wav_file_path, mel_path])\n",
|
||||
" # save TTS mel\n",
|
||||
" mel = postnet_outputs[idx]\n",
|
||||
" mel_length = mel_lengths[idx]\n",
|
||||
" mel = mel[:mel_length, :].T\n",
|
||||
" np.save(mel_path, mel)\n",
|
||||
"\n",
|
||||
" # for wavernn\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n",
|
||||
" \n",
|
||||
" # for pwgan\n",
|
||||
" with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||
" for data in metadata:\n",
|
||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
||||
" metadata.append([wav_file_path, mel_path])\n",
|
||||
"\n",
|
||||
" print(np.mean(losses))\n",
|
||||
" print(np.mean(postnet_losses))"
|
||||
" except Exception as e:\n",
|
||||
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
||||
"\n",
|
||||
" # Calculate and log mean losses\n",
|
||||
" mean_loss = np.mean(losses)\n",
|
||||
" mean_postnet_loss = np.mean(postnet_losses)\n",
|
||||
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
||||
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
||||
"\n",
|
||||
"# Close the log file\n",
|
||||
"log_file.close()\n",
|
||||
"\n",
|
||||
"# For wavernn\n",
|
||||
"if not DRY_RUN:\n",
|
||||
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
||||
"\n",
|
||||
"# For pwgan\n",
|
||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||
" for data in metadata:\n",
|
||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
||||
"\n",
|
||||
"# Print mean losses\n",
|
||||
"print(f\"Mean Loss: {mean_loss}\")\n",
|
||||
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
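The extraction cell above leaves two index files next to the saved spectrograms: `OUT_PATH/dataset_ids.pkl` (the pickled list of file names, written for WaveRNN-style training) and `OUT_PATH/metadata.txt`, where each line is `<wav_path>|<mel_path>.npy` (written for ParallelWaveGAN-style training). A minimal sketch of reading them back, assuming a hypothetical `OUT_PATH` and that the notebook has already been run:

```python
import os
import pickle

import numpy as np

OUT_PATH = "output/extracted_mels"  # hypothetical; use the same OUT_PATH as in the notebook

# ids written for WaveRNN-style training
with open(os.path.join(OUT_PATH, "dataset_ids.pkl"), "rb") as f:
    file_idxs = pickle.load(f)
print(f"{len(file_idxs)} utterances extracted")

# "<wav_path>|<mel_path>.npy" pairs written for ParallelWaveGAN-style training
with open(os.path.join(OUT_PATH, "metadata.txt")) as f:
    for line in f:
        wav_path, mel_path = line.strip().split("|")
        mel = np.load(mel_path)  # the loop above saves the mel transposed (frequency bins first)
        print(wav_path, mel.shape)
```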
@@ -1,5 +1,11 @@
[build-system]
requires = ["setuptools", "wheel", "cython==0.29.30", "numpy==1.22.0", "packaging"]
requires = [
    "setuptools",
    "wheel",
    "cython~=0.29.30",
    "numpy>=1.22.0",
    "packaging",
]

[flake8]
max-line-length=120
@@ -7,25 +13,6 @@ max-line-length=120
[tool.black]
line-length = 120
target-version = ['py39']
exclude = '''

(
  /(
    \.eggs         # exclude a few common directories in the
  | \.git          # root of the project
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
  )/
  | foo.py           # also separately exclude a file named foo.py in
                     # the root of the project
)
'''

[tool.isort]
line_length = 120
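The rewritten `requires` list swaps the exact pins for looser specifiers: `cython~=0.29.30` is a compatible-release constraint (>= 0.29.30 but < 0.30.0) and `numpy>=1.22.0` only sets a floor. A small sketch with the `packaging` library (itself listed as a build requirement above) showing how these specifiers resolve:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

cython_spec = SpecifierSet("~=0.29.30")
numpy_spec = SpecifierSet(">=1.22.0")

print(Version("0.29.36") in cython_spec)  # True: later patch releases still match
print(Version("0.30.0") in cython_spec)   # False: the next minor series is excluded
print(Version("1.25.2") in numpy_spec)    # True: any newer numpy is allowed
```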
@@ -5,26 +5,27 @@ cython==0.29.30
scipy>=1.11.2
torch>=1.7
torchaudio
soundfile
librosa==0.10.0.*
soundfile==0.12.*
librosa==0.10.*
scikit-learn==1.3.0
numba==0.55.1;python_version<"3.9"
numba==0.57.0;python_version>="3.9"
inflect==5.6.0
tqdm
anyascii
pyyaml
fsspec>=2021.04.0
aiohttp
packaging
inflect==5.6.*
tqdm==4.64.*
anyascii==0.3.*
pyyaml==6.*
fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
aiohttp==3.8.*
packaging==23.1
# deps for examples
flask
flask==2.*
# deps for inference
pysbd
pysbd==0.3.4
# deps for notebooks
umap-learn==0.5.1
pandas
umap-learn==0.5.*
pandas>=1.4,<2.0
# deps for training
matplotlib
matplotlib==3.7.*
# coqui stack
trainer
# config management
@@ -39,14 +40,14 @@ jamo
nltk
g2pkk>=0.1.1
# deps for bangla
bangla==0.0.2
bangla
bnnumerizer
bnunicodenormalizer==0.1.1
bnunicodenormalizer
#deps for tortoise
k_diffusion
einops
transformers
einops==0.6.*
transformers==4.33.*
#deps for bark
encodec
encodec==0.1.*
# deps for XTTS
unidecode
unidecode==1.3.*
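Most entries above move from floating versions to `==x.y.*` style pins. A quick, hypothetical way to see which of the newly pinned packages are present in the current environment (package names taken from the list above):

```python
from importlib.metadata import PackageNotFoundError, version

for pkg in ["librosa", "soundfile", "transformers", "encodec", "pandas"]:
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg} is not installed")
```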
@@ -0,0 +1,32 @@
import argparse
from pathlib import Path


def replace_between_markers(content, marker: str, replacement: str) -> str:
    start_marker = f"<!-- begin-{marker} -->\n\n"
    end_marker = f"\n\n<!-- end-{marker} -->\n"
    start_index = content.index(start_marker) + len(start_marker)
    end_index = content.index(end_marker)
    content = content[:start_index] + replacement + content[end_index:]
    return content


def sync_readme():
    ap = argparse.ArgumentParser()
    ap.add_argument("--check", action="store_true", default=False)
    args = ap.parse_args()
    readme_path = Path(__file__).parent.parent / "README.md"
    orig_content = readme_path.read_text()
    from TTS.bin.synthesize import description

    new_content = replace_between_markers(orig_content, "tts-readme", description.strip())
    if args.check:
        if orig_content != new_content:
            print("README.md is out of sync; please edit TTS/bin/TTS_README.md and run scripts/sync_readme.py")
            exit(42)
    readme_path.write_text(new_content)
    print("Updated README.md")


if __name__ == "__main__":
    sync_readme()
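The helper above only touches the text between the `<!-- begin-tts-readme -->` / `<!-- end-tts-readme -->` markers in README.md. A self-contained sketch of that behaviour on a made-up string (the function body is repeated here so the snippet runs on its own):

```python
def replace_between_markers(content, marker, replacement):
    start_marker = f"<!-- begin-{marker} -->\n\n"
    end_marker = f"\n\n<!-- end-{marker} -->\n"
    start_index = content.index(start_marker) + len(start_marker)
    end_index = content.index(end_marker)
    return content[:start_index] + replacement + content[end_index:]


sample = "intro\n<!-- begin-tts-readme -->\n\nOLD TEXT\n\n<!-- end-tts-readme -->\noutro\n"
print(replace_between_markers(sample, "tts-readme", "NEW TEXT"))
# intro
# <!-- begin-tts-readme -->
#
# NEW TEXT
#
# <!-- end-tts-readme -->
# outro
```

With `--check`, `sync_readme()` exits with code 42 when the regenerated README differs from the committed one, which is exactly what the new test below relies on.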
@@ -0,0 +1,9 @@
import subprocess
import sys
from pathlib import Path


def test_readme_up_to_date():
    root = Path(__file__).parent.parent.parent
    sync_readme = root / "scripts" / "sync_readme.py"
    subprocess.check_call([sys.executable, str(sync_readme), "--check"], cwd=root)
@@ -3,13 +3,19 @@ import glob
import os
import shutil

import torch

from tests import get_tests_data_path, get_tests_output_path, run_cli
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager

MODELS_WITH_SEP_TESTS = ["bark", "xtts"]
MODELS_WITH_SEP_TESTS = [
    "tts_models/multilingual/multi-dataset/bark",
    "tts_models/en/multi-dataset/tortoise-v2",
    "tts_models/multilingual/multi-dataset/xtts_v1",
]


def run_models(offset=0, step=1):
@@ -17,7 +23,8 @@ def run_models(offset=0, step=1):
    print(" > Run synthesizer with all the models.")
    output_path = os.path.join(get_tests_output_path(), "output.wav")
    manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
    model_names = [name for name in manager.list_models() if name in MODELS_WITH_SEP_TESTS]
    model_names = [name for name in manager.list_models() if name not in MODELS_WITH_SEP_TESTS]
    print("Model names:", model_names)
    for model_name in model_names[offset::step]:
        print(f"\n > Run - {model_name}")
        model_path, _, _ = manager.download_model(model_name)
@@ -67,23 +74,83 @@


def test_xtts():
    """XTTS is too big to run on github actions. We need to test it locally"""
    output_path = os.path.join(get_tests_output_path(), "output.wav")
    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
    run_cli(
        "yes | "
        f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
        f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
        f'--speaker_wav "{speaker_wav}" --language_idx "en"'
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        run_cli(
            "yes | "
            f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
            f'--speaker_wav "{speaker_wav}" --language_idx "en"'
        )
    else:
        run_cli(
            "yes | "
            f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
            f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
            f'--speaker_wav "{speaker_wav}" --language_idx "en"'
        )

def test_xtts_streaming():
    """Testing the new inference_stream method"""
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts
    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
    model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
    config = XttsConfig()
    config.load_json(os.path.join(model_path, "config.json"))
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir=model_path)
    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    print("Computing speaker latents...")
    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)

    print("Inference...")
    chunks = model.inference_stream(
        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
        "en",
        gpt_cond_latent,
        speaker_embedding
    )
    wav_chuncks = []
    for i, chunk in enumerate(chunks):
        if i == 0:
            assert chunk.shape[-1] > 5000
        wav_chuncks.append(chunk)
    assert len(wav_chuncks) > 1
def test_tortoise():
    output_path = os.path.join(get_tests_output_path(), "output.wav")
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        run_cli(
            f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
        )
    else:
        run_cli(
            f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
            f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
        )


def test_bark():
    """Bark is too big to run on github actions. We need to test it locally"""
    output_path = os.path.join(get_tests_output_path(), "output.wav")
    run_cli(
        f" tts --model_name tts_models/multilingual/multi-dataset/bark "
        f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
    )
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        run_cli(
            f" tts --model_name tts_models/multilingual/multi-dataset/bark "
            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
        )
    else:
        run_cli(
            f" tts --model_name tts_models/multilingual/multi-dataset/bark "
            f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
        )


def test_voice_conversion():
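`test_xtts_streaming()` above only checks that `inference_stream` yields more than one chunk; it never assembles them. A hedged sketch of how the streamed chunks could be stitched into a wav file, reusing the `model`, `config`, `gpt_cond_latent` and `speaker_embedding` objects set up in that test (the output path and the assumption that `config.audio.output_sample_rate` holds the model's output rate are mine, not part of the test):

```python
import torch
import torchaudio

chunks = model.inference_stream(
    "This is a streaming synthesis example.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
)

wav_chunks = []
for chunk in chunks:
    wav_chunks.append(chunk.squeeze().cpu())  # each chunk is a 1-D waveform tensor

wav = torch.cat(wav_chunks, dim=0)
# assumed sample-rate field; adjust if the config exposes it differently
torchaudio.save("xtts_stream.wav", wav.unsqueeze(0), config.audio.output_sample_rate)
```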