mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'coqui-ai:dev' into dev
This commit is contained in:
commit
818aa0eb7e
|
@ -0,0 +1,52 @@
|
|||
name: zoo-tests-tortoise
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
jobs:
|
||||
check_skip:
|
||||
runs-on: ubuntu-latest
|
||||
if: "! contains(github.event.head_commit.message, '[ci skip]')"
|
||||
steps:
|
||||
- run: echo "${{ github.event.head_commit.message }}"
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: [3.9, "3.10", "3.11"]
|
||||
experimental: [false]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
architecture: x64
|
||||
cache: 'pip'
|
||||
cache-dependency-path: 'requirements*'
|
||||
- name: check OS
|
||||
run: cat /etc/os-release
|
||||
- name: set ENV
|
||||
run: export TRAINER_TELEMETRY=0
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y git make gcc
|
||||
sudo apt-get install espeak espeak-ng
|
||||
make system-deps
|
||||
- name: Install/upgrade Python setup deps
|
||||
run: python3 -m pip install --upgrade pip setuptools wheel
|
||||
- name: Replace scarf urls
|
||||
run: |
|
||||
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
|
||||
- name: Install TTS
|
||||
run: |
|
||||
python3 -m pip install .[all]
|
||||
python3 setup.py egg_info
|
||||
- name: Unit tests
|
||||
run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_tortoise
|
165
README.md
165
README.md
|
@ -146,7 +146,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea
|
|||
You can also help us implement more models.
|
||||
|
||||
## Installation
|
||||
🐸TTS is tested on Ubuntu 18.04 with **python >= 3.7, < 3.11.**.
|
||||
🐸TTS is tested on Ubuntu 18.04 with **python >= 3.9, < 3.12.**.
|
||||
|
||||
If you are only interested in [synthesizing speech](https://tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option.
|
||||
|
||||
|
@ -198,17 +198,18 @@ from TTS.api import TTS
|
|||
# Get device
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# List available 🐸TTS models and choose the first one
|
||||
model_name = TTS().list_models()[0]
|
||||
# List available 🐸TTS models
|
||||
print(TTS().list_models())
|
||||
|
||||
# Init TTS
|
||||
tts = TTS(model_name).to(device)
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1").to(device)
|
||||
|
||||
# Run TTS
|
||||
# ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
|
||||
# Text to speech with a numpy output
|
||||
wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
|
||||
# ❗ Since this model is multi-lingual voice cloning model, we must set the target speaker_wav and language
|
||||
# Text to speech list of amplitude values as output
|
||||
wav = tts.tts(text="Hello world!", speaker_wav="my/cloning/audio.wav", language="en")
|
||||
# Text to speech to a file
|
||||
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
|
||||
tts.tts_to_file(text="Hello world!", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
|
||||
```
|
||||
|
||||
#### Running a single speaker model
|
||||
|
@ -294,99 +295,135 @@ api.tts_with_vc_to_file(
|
|||
```
|
||||
|
||||
### Command-line `tts`
|
||||
|
||||
<!-- begin-tts-readme -->
|
||||
|
||||
Synthesize speech on command line.
|
||||
|
||||
You can either use your trained model or choose a model from the provided list.
|
||||
|
||||
If you don't specify any models, then it uses LJSpeech based English model.
|
||||
|
||||
#### Single Speaker Models
|
||||
|
||||
- List provided models:
|
||||
|
||||
```
|
||||
$ tts --list_models
|
||||
```
|
||||
```
|
||||
$ tts --list_models
|
||||
```
|
||||
|
||||
- Get model info (for both tts_models and vocoder_models):
|
||||
- Query by type/name:
|
||||
The model_info_by_name uses the name as it from the --list_models.
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
|
||||
```
|
||||
```
|
||||
$ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
|
||||
```
|
||||
- Query by type/idx:
|
||||
The model_query_idx uses the corresponding idx from --list_models.
|
||||
```
|
||||
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
|
||||
```
|
||||
For example:
|
||||
- Query by type/name:
|
||||
The model_info_by_name uses the name as it from the --list_models.
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
For example:
|
||||
```
|
||||
$ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
|
||||
$ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
|
||||
```
|
||||
- Query by type/idx:
|
||||
The model_query_idx uses the corresponding idx from --list_models.
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx tts_models/3
|
||||
```
|
||||
```
|
||||
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx tts_models/3
|
||||
```
|
||||
|
||||
- Query info for model info by full name:
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
|
||||
- Run TTS with default models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav
|
||||
```
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run TTS and pipe out the generated TTS wav file data:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay
|
||||
```
|
||||
|
||||
- Run TTS and define speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "coqui_studio/<language>/<dataset>/<model_name>" --speed 1.2 --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run a TTS model with its default vocoder model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
|
||||
```
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run with specific TTS and vocoder models from the list:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS model (Using Griffin-Lim Vocoder):
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
```
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS and Vocoder models:
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
|
||||
```
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
|
||||
```
|
||||
|
||||
#### Multi-speaker Models
|
||||
|
||||
- List the available speakers and choose a <speaker_id> among them:
|
||||
|
||||
```
|
||||
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||
```
|
||||
```
|
||||
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||
```
|
||||
|
||||
- Run the multi-speaker TTS model with the target speaker ID:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||
```
|
||||
```
|
||||
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
- Run your own multi-speaker TTS model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||
```
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
### Voice Conversion Models
|
||||
|
||||
```
|
||||
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
|
||||
```
|
||||
|
||||
<!-- end-tts-readme -->
|
||||
|
||||
## Directory Structure
|
||||
```
|
||||
|
|
|
@ -5,12 +5,27 @@
|
|||
"xtts_v1": {
|
||||
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
|
||||
"hf_url": [
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/model.pth",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/config.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/vocab.json"
|
||||
],
|
||||
"default_vocoder": null,
|
||||
"commit": "e9a1953e",
|
||||
"commit": "e5140314",
|
||||
"license": "CPML",
|
||||
"contact": "info@coqui.ai",
|
||||
"tos_required": true
|
||||
},
|
||||
"xtts_v1.1": {
|
||||
"description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
|
||||
"hf_url": [
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
|
||||
],
|
||||
"model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
|
||||
"default_vocoder": null,
|
||||
"commit": "82910a63",
|
||||
"license": "CPML",
|
||||
"contact": "info@coqui.ai",
|
||||
"tos_required": true
|
||||
|
@ -728,6 +743,18 @@
|
|||
"license": "Apache 2.0"
|
||||
}
|
||||
}
|
||||
},
|
||||
"be": {
|
||||
"common-voice": {
|
||||
"glow-tts":{
|
||||
"description": "Belarusian GlowTTS model created by @alex73 (Github).",
|
||||
"github_rls_url":"https://coqui.gateway.scarf.sh/v0.16.6/tts_models--be--common-voice--glow-tts.zip",
|
||||
"default_vocoder": "vocoder_models/be/common-voice/hifigan",
|
||||
"commit": "c0aabb85",
|
||||
"license": "CC-BY-SA 4.0",
|
||||
"contact": "alex73mail@gmail.com"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"vocoder_models": {
|
||||
|
@ -879,6 +906,17 @@
|
|||
"commit": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"be": {
|
||||
"common-voice": {
|
||||
"hifigan": {
|
||||
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.16.6/vocoder_models--be--common-voice--hifigan.zip",
|
||||
"description": "Belarusian HiFiGAN model created by @alex73 (Github).",
|
||||
"author": "@alex73",
|
||||
"license": "CC-BY-SA 4.0",
|
||||
"commit": "c0aabb85"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"voice_conversion_models": {
|
||||
|
@ -894,4 +932,4 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1 +1 @@
|
|||
0.17.4
|
||||
0.18.2
|
||||
|
|
23
TTS/api.py
23
TTS/api.py
|
@ -17,7 +17,7 @@ class TTS(nn.Module):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
model_name: str = "",
|
||||
model_path: str = None,
|
||||
config_path: str = None,
|
||||
vocoder_path: str = None,
|
||||
|
@ -105,8 +105,8 @@ class TTS(nn.Module):
|
|||
|
||||
@property
|
||||
def is_multi_lingual(self):
|
||||
# TODO: fix this
|
||||
if "xtts" in self.model_name:
|
||||
# Not sure what sets this to None, but applied a fix to prevent crashing.
|
||||
if isinstance(self.model_name, str) and "xtts" in self.model_name:
|
||||
return True
|
||||
if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
|
||||
return self.synthesizer.tts_model.language_manager.num_languages > 1
|
||||
|
@ -264,6 +264,7 @@ class TTS(nn.Module):
|
|||
language: str = None,
|
||||
emotion: str = None,
|
||||
speed: float = 1.0,
|
||||
pipe_out = None,
|
||||
file_path: str = None,
|
||||
) -> Union[np.ndarray, str]:
|
||||
"""Convert text to speech using Coqui Studio models. Use `CS_API` class if you are only interested in the API.
|
||||
|
@ -280,6 +281,8 @@ class TTS(nn.Module):
|
|||
with "V1" model. Defaults to None.
|
||||
speed (float, optional):
|
||||
Speed of the speech. Defaults to 1.0.
|
||||
pipe_out (BytesIO, optional):
|
||||
Flag to stdout the generated TTS wav file for shell pipe.
|
||||
file_path (str, optional):
|
||||
Path to save the output file. When None it returns the `np.ndarray` of waveform. Defaults to None.
|
||||
|
||||
|
@ -293,6 +296,7 @@ class TTS(nn.Module):
|
|||
speaker_name=speaker_name,
|
||||
language=language,
|
||||
speed=speed,
|
||||
pipe_out=pipe_out,
|
||||
emotion=emotion,
|
||||
file_path=file_path,
|
||||
)[0]
|
||||
|
@ -355,6 +359,7 @@ class TTS(nn.Module):
|
|||
speaker_wav: str = None,
|
||||
emotion: str = None,
|
||||
speed: float = 1.0,
|
||||
pipe_out = None,
|
||||
file_path: str = "output.wav",
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -376,6 +381,8 @@ class TTS(nn.Module):
|
|||
Emotion to use for 🐸Coqui Studio models. Defaults to "Neutral".
|
||||
speed (float, optional):
|
||||
Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None.
|
||||
pipe_out (BytesIO, optional):
|
||||
Flag to stdout the generated TTS wav file for shell pipe.
|
||||
file_path (str, optional):
|
||||
Output file path. Defaults to "output.wav".
|
||||
kwargs (dict, optional):
|
||||
|
@ -385,10 +392,16 @@ class TTS(nn.Module):
|
|||
|
||||
if self.csapi is not None:
|
||||
return self.tts_coqui_studio(
|
||||
text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path
|
||||
text=text,
|
||||
speaker_name=speaker,
|
||||
language=language,
|
||||
emotion=emotion,
|
||||
speed=speed,
|
||||
file_path=file_path,
|
||||
pipe_out=pipe_out,
|
||||
)
|
||||
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
|
||||
self.synthesizer.save_wav(wav=wav, path=file_path)
|
||||
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
|
||||
return file_path
|
||||
|
||||
def voice_conversion(
|
||||
|
|
|
@ -2,15 +2,139 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import sys
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
# pylint: disable=redefined-outer-name, unused-argument
|
||||
from pathlib import Path
|
||||
|
||||
from TTS.api import TTS
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
description = """
|
||||
Synthesize speech on command line.
|
||||
|
||||
You can either use your trained model or choose a model from the provided list.
|
||||
|
||||
If you don't specify any models, then it uses LJSpeech based English model.
|
||||
|
||||
#### Single Speaker Models
|
||||
|
||||
- List provided models:
|
||||
|
||||
```
|
||||
$ tts --list_models
|
||||
```
|
||||
|
||||
- Get model info (for both tts_models and vocoder_models):
|
||||
|
||||
- Query by type/name:
|
||||
The model_info_by_name uses the name as it from the --list_models.
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
For example:
|
||||
```
|
||||
$ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
|
||||
$ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
|
||||
```
|
||||
- Query by type/idx:
|
||||
The model_query_idx uses the corresponding idx from --list_models.
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx tts_models/3
|
||||
```
|
||||
|
||||
- Query info for model info by full name:
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
|
||||
- Run TTS with default models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run TTS and pipe out the generated TTS wav file data:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay
|
||||
```
|
||||
|
||||
- Run TTS and define speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "coqui_studio/<language>/<dataset>/<model_name>" --speed 1.2 --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run a TTS model with its default vocoder model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run with specific TTS and vocoder models from the list:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS model (Using Griffin-Lim Vocoder):
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS and Vocoder models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
|
||||
```
|
||||
|
||||
#### Multi-speaker Models
|
||||
|
||||
- List the available speakers and choose a <speaker_id> among them:
|
||||
|
||||
```
|
||||
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||
```
|
||||
|
||||
- Run the multi-speaker TTS model with the target speaker ID:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
- Run your own multi-speaker TTS model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
### Voice Conversion Models
|
||||
|
||||
```
|
||||
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
|
@ -24,92 +148,6 @@ def str2bool(v):
|
|||
|
||||
|
||||
def main():
|
||||
description = """Synthesize speech on command line.
|
||||
|
||||
You can either use your trained model or choose a model from the provided list.
|
||||
|
||||
If you don't specify any models, then it uses LJSpeech based English model.
|
||||
|
||||
## Example Runs
|
||||
|
||||
### Single Speaker Models
|
||||
|
||||
- List provided models:
|
||||
|
||||
```
|
||||
$ tts --list_models
|
||||
```
|
||||
|
||||
- Query info for model info by idx:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
|
||||
```
|
||||
|
||||
- Query info for model info by full name:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
|
||||
- Run TTS with default models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS"
|
||||
```
|
||||
|
||||
- Run a TTS model with its default vocoder model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>
|
||||
```
|
||||
|
||||
- Run with specific TTS and vocoder models from the list:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --output_path
|
||||
```
|
||||
|
||||
- Run your own TTS model (Using Griffin-Lim Vocoder):
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS and Vocoder models:
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth --out_path output/path/speech.wav
|
||||
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
|
||||
```
|
||||
|
||||
### Multi-speaker Models
|
||||
|
||||
- List the available speakers and choose as <speaker_id> among them:
|
||||
|
||||
```
|
||||
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||
```
|
||||
|
||||
- Run the multi-speaker TTS model with the target speaker ID:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
- Run your own multi-speaker TTS model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
### Voice Conversion Models
|
||||
|
||||
```
|
||||
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
|
||||
```
|
||||
"""
|
||||
# We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
|
||||
# documentation in sync more easily.
|
||||
parser = argparse.ArgumentParser(
|
||||
description=description.replace(" ```\n", ""),
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
|
@ -203,6 +241,20 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
help="Language to condition the model with. Only available for 🐸Coqui Studio `XTTS-multilingual` model.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pipe_out",
|
||||
help="stdout the generated TTS wav file for shell pipe.",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speed",
|
||||
type=float,
|
||||
help="Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
# args for multi-speaker synthesis
|
||||
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
|
||||
|
@ -310,162 +362,177 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
if not any(check_args):
|
||||
parser.parse_args(["-h"])
|
||||
|
||||
# load model manager
|
||||
path = Path(__file__).parent / "../.models.json"
|
||||
manager = ModelManager(path, progress_bar=args.progress_bar)
|
||||
api = TTS()
|
||||
pipe_out = sys.stdout if args.pipe_out else None
|
||||
|
||||
tts_path = None
|
||||
tts_config_path = None
|
||||
speakers_file_path = None
|
||||
language_ids_file_path = None
|
||||
vocoder_path = None
|
||||
vocoder_config_path = None
|
||||
encoder_path = None
|
||||
encoder_config_path = None
|
||||
vc_path = None
|
||||
vc_config_path = None
|
||||
model_dir = None
|
||||
with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
|
||||
# Late-import to make things load faster
|
||||
from TTS.api import TTS
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
|
||||
# CASE1 #list : list pre-trained TTS models
|
||||
if args.list_models:
|
||||
manager.add_cs_api_models(api.list_models())
|
||||
manager.list_models()
|
||||
sys.exit()
|
||||
# load model manager
|
||||
path = Path(__file__).parent / "../.models.json"
|
||||
manager = ModelManager(path, progress_bar=args.progress_bar)
|
||||
api = TTS()
|
||||
|
||||
# CASE2 #info : model info for pre-trained TTS models
|
||||
if args.model_info_by_idx:
|
||||
model_query = args.model_info_by_idx
|
||||
manager.model_info_by_idx(model_query)
|
||||
sys.exit()
|
||||
tts_path = None
|
||||
tts_config_path = None
|
||||
speakers_file_path = None
|
||||
language_ids_file_path = None
|
||||
vocoder_path = None
|
||||
vocoder_config_path = None
|
||||
encoder_path = None
|
||||
encoder_config_path = None
|
||||
vc_path = None
|
||||
vc_config_path = None
|
||||
model_dir = None
|
||||
|
||||
if args.model_info_by_name:
|
||||
model_query_full_name = args.model_info_by_name
|
||||
manager.model_info_by_full_name(model_query_full_name)
|
||||
sys.exit()
|
||||
# CASE1 #list : list pre-trained TTS models
|
||||
if args.list_models:
|
||||
manager.add_cs_api_models(api.list_models())
|
||||
manager.list_models()
|
||||
sys.exit()
|
||||
|
||||
# CASE3: TTS with coqui studio models
|
||||
if "coqui_studio" in args.model_name:
|
||||
print(" > Using 🐸Coqui Studio model: ", args.model_name)
|
||||
api = TTS(model_name=args.model_name, cs_api_model=args.cs_model)
|
||||
api.tts_to_file(text=args.text, emotion=args.emotion, file_path=args.out_path, language=args.language)
|
||||
print(" > Saving output to ", args.out_path)
|
||||
return
|
||||
# CASE2 #info : model info for pre-trained TTS models
|
||||
if args.model_info_by_idx:
|
||||
model_query = args.model_info_by_idx
|
||||
manager.model_info_by_idx(model_query)
|
||||
sys.exit()
|
||||
|
||||
# CASE4: load pre-trained model paths
|
||||
if args.model_name is not None and not args.model_path:
|
||||
model_path, config_path, model_item = manager.download_model(args.model_name)
|
||||
# tts model
|
||||
if model_item["model_type"] == "tts_models":
|
||||
tts_path = model_path
|
||||
tts_config_path = config_path
|
||||
if "default_vocoder" in model_item:
|
||||
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
|
||||
if args.model_info_by_name:
|
||||
model_query_full_name = args.model_info_by_name
|
||||
manager.model_info_by_full_name(model_query_full_name)
|
||||
sys.exit()
|
||||
|
||||
# voice conversion model
|
||||
if model_item["model_type"] == "voice_conversion_models":
|
||||
vc_path = model_path
|
||||
vc_config_path = config_path
|
||||
# CASE3: TTS with coqui studio models
|
||||
if "coqui_studio" in args.model_name:
|
||||
print(" > Using 🐸Coqui Studio model: ", args.model_name)
|
||||
api = TTS(model_name=args.model_name, cs_api_model=args.cs_model)
|
||||
api.tts_to_file(
|
||||
text=args.text,
|
||||
emotion=args.emotion,
|
||||
file_path=args.out_path,
|
||||
language=args.language,
|
||||
speed=args.speed,
|
||||
pipe_out=pipe_out,
|
||||
)
|
||||
print(" > Saving output to ", args.out_path)
|
||||
return
|
||||
|
||||
# tts model with multiple files to be loaded from the directory path
|
||||
if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
|
||||
model_dir = model_path
|
||||
tts_path = None
|
||||
tts_config_path = None
|
||||
args.vocoder_name = None
|
||||
# CASE4: load pre-trained model paths
|
||||
if args.model_name is not None and not args.model_path:
|
||||
model_path, config_path, model_item = manager.download_model(args.model_name)
|
||||
# tts model
|
||||
if model_item["model_type"] == "tts_models":
|
||||
tts_path = model_path
|
||||
tts_config_path = config_path
|
||||
if "default_vocoder" in model_item:
|
||||
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
|
||||
|
||||
# load vocoder
|
||||
if args.vocoder_name is not None and not args.vocoder_path:
|
||||
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
|
||||
# voice conversion model
|
||||
if model_item["model_type"] == "voice_conversion_models":
|
||||
vc_path = model_path
|
||||
vc_config_path = config_path
|
||||
|
||||
# CASE5: set custom model paths
|
||||
if args.model_path is not None:
|
||||
tts_path = args.model_path
|
||||
tts_config_path = args.config_path
|
||||
speakers_file_path = args.speakers_file_path
|
||||
language_ids_file_path = args.language_ids_file_path
|
||||
# tts model with multiple files to be loaded from the directory path
|
||||
if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
|
||||
model_dir = model_path
|
||||
tts_path = None
|
||||
tts_config_path = None
|
||||
args.vocoder_name = None
|
||||
|
||||
if args.vocoder_path is not None:
|
||||
vocoder_path = args.vocoder_path
|
||||
vocoder_config_path = args.vocoder_config_path
|
||||
# load vocoder
|
||||
if args.vocoder_name is not None and not args.vocoder_path:
|
||||
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
|
||||
|
||||
if args.encoder_path is not None:
|
||||
encoder_path = args.encoder_path
|
||||
encoder_config_path = args.encoder_config_path
|
||||
# CASE5: set custom model paths
|
||||
if args.model_path is not None:
|
||||
tts_path = args.model_path
|
||||
tts_config_path = args.config_path
|
||||
speakers_file_path = args.speakers_file_path
|
||||
language_ids_file_path = args.language_ids_file_path
|
||||
|
||||
device = args.device
|
||||
if args.use_cuda:
|
||||
device = "cuda"
|
||||
if args.vocoder_path is not None:
|
||||
vocoder_path = args.vocoder_path
|
||||
vocoder_config_path = args.vocoder_config_path
|
||||
|
||||
# load models
|
||||
synthesizer = Synthesizer(
|
||||
tts_path,
|
||||
tts_config_path,
|
||||
speakers_file_path,
|
||||
language_ids_file_path,
|
||||
vocoder_path,
|
||||
vocoder_config_path,
|
||||
encoder_path,
|
||||
encoder_config_path,
|
||||
vc_path,
|
||||
vc_config_path,
|
||||
model_dir,
|
||||
args.voice_dir,
|
||||
).to(device)
|
||||
if args.encoder_path is not None:
|
||||
encoder_path = args.encoder_path
|
||||
encoder_config_path = args.encoder_config_path
|
||||
|
||||
# query speaker ids of a multi-speaker model.
|
||||
if args.list_speaker_idxs:
|
||||
print(
|
||||
" > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
|
||||
)
|
||||
print(synthesizer.tts_model.speaker_manager.name_to_id)
|
||||
return
|
||||
device = args.device
|
||||
if args.use_cuda:
|
||||
device = "cuda"
|
||||
|
||||
# query langauge ids of a multi-lingual model.
|
||||
if args.list_language_idxs:
|
||||
print(
|
||||
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||
)
|
||||
print(synthesizer.tts_model.language_manager.name_to_id)
|
||||
return
|
||||
# load models
|
||||
synthesizer = Synthesizer(
|
||||
tts_path,
|
||||
tts_config_path,
|
||||
speakers_file_path,
|
||||
language_ids_file_path,
|
||||
vocoder_path,
|
||||
vocoder_config_path,
|
||||
encoder_path,
|
||||
encoder_config_path,
|
||||
vc_path,
|
||||
vc_config_path,
|
||||
model_dir,
|
||||
args.voice_dir,
|
||||
).to(device)
|
||||
|
||||
# check the arguments against a multi-speaker model.
|
||||
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
|
||||
print(
|
||||
" [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to "
|
||||
"select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
|
||||
)
|
||||
return
|
||||
# query speaker ids of a multi-speaker model.
|
||||
if args.list_speaker_idxs:
|
||||
print(
|
||||
" > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
|
||||
)
|
||||
print(synthesizer.tts_model.speaker_manager.name_to_id)
|
||||
return
|
||||
|
||||
# RUN THE SYNTHESIS
|
||||
if args.text:
|
||||
print(" > Text: {}".format(args.text))
|
||||
# query langauge ids of a multi-lingual model.
|
||||
if args.list_language_idxs:
|
||||
print(
|
||||
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||
)
|
||||
print(synthesizer.tts_model.language_manager.name_to_id)
|
||||
return
|
||||
|
||||
# kick it
|
||||
if tts_path is not None:
|
||||
wav = synthesizer.tts(
|
||||
args.text,
|
||||
speaker_name=args.speaker_idx,
|
||||
language_name=args.language_idx,
|
||||
speaker_wav=args.speaker_wav,
|
||||
reference_wav=args.reference_wav,
|
||||
style_wav=args.capacitron_style_wav,
|
||||
style_text=args.capacitron_style_text,
|
||||
reference_speaker_name=args.reference_speaker_idx,
|
||||
)
|
||||
elif vc_path is not None:
|
||||
wav = synthesizer.voice_conversion(
|
||||
source_wav=args.source_wav,
|
||||
target_wav=args.target_wav,
|
||||
)
|
||||
elif model_dir is not None:
|
||||
wav = synthesizer.tts(
|
||||
args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
|
||||
)
|
||||
# check the arguments against a multi-speaker model.
|
||||
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
|
||||
print(
|
||||
" [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to "
|
||||
"select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
|
||||
)
|
||||
return
|
||||
|
||||
# save the results
|
||||
print(" > Saving output to {}".format(args.out_path))
|
||||
synthesizer.save_wav(wav, args.out_path)
|
||||
# RUN THE SYNTHESIS
|
||||
if args.text:
|
||||
print(" > Text: {}".format(args.text))
|
||||
|
||||
# kick it
|
||||
if tts_path is not None:
|
||||
wav = synthesizer.tts(
|
||||
args.text,
|
||||
speaker_name=args.speaker_idx,
|
||||
language_name=args.language_idx,
|
||||
speaker_wav=args.speaker_wav,
|
||||
reference_wav=args.reference_wav,
|
||||
style_wav=args.capacitron_style_wav,
|
||||
style_text=args.capacitron_style_text,
|
||||
reference_speaker_name=args.reference_speaker_idx,
|
||||
)
|
||||
elif vc_path is not None:
|
||||
wav = synthesizer.voice_conversion(
|
||||
source_wav=args.source_wav,
|
||||
target_wav=args.target_wav,
|
||||
)
|
||||
elif model_dir is not None:
|
||||
wav = synthesizer.tts(
|
||||
args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
|
||||
)
|
||||
|
||||
# save the results
|
||||
print(" > Saving output to {}".format(args.out_path))
|
||||
synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -9,6 +9,8 @@ import numpy as np
|
|||
import requests
|
||||
from scipy.io import wavfile
|
||||
|
||||
from TTS.utils.audio.numpy_transforms import save_wav
|
||||
|
||||
|
||||
class Speaker(object):
|
||||
"""Convert dict to object."""
|
||||
|
@ -288,6 +290,7 @@ class CS_API:
|
|||
speaker_id=None,
|
||||
emotion=None,
|
||||
speed=1.0,
|
||||
pipe_out=None,
|
||||
language=None,
|
||||
file_path: str = None,
|
||||
) -> str:
|
||||
|
@ -300,6 +303,7 @@ class CS_API:
|
|||
speaker_id (str): Speaker ID. If None, the speaker name is used.
|
||||
emotion (str): Emotion of the speaker. One of "Neutral", "Happy", "Sad", "Angry", "Dull".
|
||||
speed (float): Speed of the speech. 1.0 is normal speed.
|
||||
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
|
||||
language (str): Language of the text. If None, the default language of the speaker is used. Language is only
|
||||
supported by `XTTS-multilang` model. Currently supports en, de, es, fr, it, pt, pl. Defaults to "en".
|
||||
file_path (str): Path to save the file. If None, a temporary file is created.
|
||||
|
@ -307,7 +311,7 @@ class CS_API:
|
|||
if file_path is None:
|
||||
file_path = tempfile.mktemp(".wav")
|
||||
wav, sr = self.tts(text, speaker_name, speaker_id, emotion, speed, language)
|
||||
wavfile.write(file_path, sr, wav)
|
||||
save_wav(wav=wav, path=file_path, sample_rate=sr, pipe_out=pipe_out)
|
||||
return file_path
|
||||
|
||||
|
||||
|
|
|
@ -78,13 +78,13 @@ class XttsConfig(BaseTTSConfig):
|
|||
)
|
||||
|
||||
# inference params
|
||||
temperature: float = 0.2
|
||||
temperature: float = 0.85
|
||||
length_penalty: float = 1.0
|
||||
repetition_penalty: float = 2.0
|
||||
top_k: int = 50
|
||||
top_p: float = 0.8
|
||||
top_p: float = 0.85
|
||||
cond_free_k: float = 2.0
|
||||
diffusion_temperature: float = 1.0
|
||||
num_gpt_outputs: int = 16
|
||||
num_gpt_outputs: int = 1
|
||||
decoder_iterations: int = 30
|
||||
decoder_sampler: str = "ddim"
|
||||
|
|
|
@ -1085,31 +1085,6 @@ class GaussianDiffusion:
|
|||
}
|
||||
|
||||
|
||||
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
||||
"""
|
||||
Get a pre-defined beta schedule for the given name.
|
||||
|
||||
The beta schedule library consists of beta schedules which remain similar
|
||||
in the limit of num_diffusion_timesteps.
|
||||
Beta schedules may be added, but should not be removed or changed once
|
||||
they are committed to maintain backwards compatibility.
|
||||
"""
|
||||
if schedule_name == "linear":
|
||||
# Linear schedule from Ho et al, extended to work for any number of
|
||||
# diffusion steps.
|
||||
scale = 1000 / num_diffusion_timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
||||
elif schedule_name == "cosine":
|
||||
return betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
||||
|
||||
|
||||
class SpacedDiffusion(GaussianDiffusion):
|
||||
"""
|
||||
A diffusion process which can skip steps in a base diffusion process.
|
||||
|
|
|
@ -5,9 +5,13 @@ from tokenizers import Tokenizer
|
|||
|
||||
from TTS.tts.utils.text.cleaners import english_cleaners
|
||||
|
||||
DEFAULT_VOCAB_FILE = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
|
||||
)
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file=None, vocab_str=None):
|
||||
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, vocab_str=None):
|
||||
self.tokenizer = None
|
||||
if vocab_file is not None:
|
||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
|
|
|
@ -1170,31 +1170,6 @@ class GaussianDiffusion:
|
|||
}
|
||||
|
||||
|
||||
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
||||
"""
|
||||
Get a pre-defined beta schedule for the given name.
|
||||
|
||||
The beta schedule library consists of beta schedules which remain similar
|
||||
in the limit of num_diffusion_timesteps.
|
||||
Beta schedules may be added, but should not be removed or changed once
|
||||
they are committed to maintain backwards compatibility.
|
||||
"""
|
||||
if schedule_name == "linear":
|
||||
# Linear schedule from Ho et al, extended to work for any number of
|
||||
# diffusion steps.
|
||||
scale = 1000 / num_diffusion_timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
||||
elif schedule_name == "cosine":
|
||||
return betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
||||
|
||||
|
||||
class SpacedDiffusion(GaussianDiffusion):
|
||||
"""
|
||||
A diffusion process which can skip steps in a base diffusion process.
|
||||
|
|
|
@ -172,7 +172,7 @@ class GPT(nn.Module):
|
|||
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
||||
}
|
||||
|
||||
def init_gpt_for_inference(self, kv_cache=True):
|
||||
def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
|
||||
seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=self.max_mel_tokens,
|
||||
|
@ -195,6 +195,17 @@ class GPT(nn.Module):
|
|||
)
|
||||
self.gpt.wte = self.mel_embedding
|
||||
|
||||
if use_deepspeed:
|
||||
import deepspeed
|
||||
self.ds_engine = deepspeed.init_inference(
|
||||
model=self.gpt_inference.half(), # Transformers models
|
||||
mp_size=1, # Number of GPU
|
||||
dtype=torch.float32, # desired data type of output
|
||||
replace_method="auto", # Lets DS autmatically identify the layer to replace
|
||||
replace_with_kernel_inject=True, # replace the model with the kernel injector
|
||||
)
|
||||
self.gpt_inference = self.ds_engine.module.eval()
|
||||
|
||||
def set_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
|
@ -543,3 +554,14 @@ class GPT(nn.Module):
|
|||
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||
return gen.sequences[:, gpt_inputs.shape[1] :], gen
|
||||
return gen[:, gpt_inputs.shape[1] :]
|
||||
|
||||
def get_generator(self, fake_inputs, **hf_generate_kwargs):
|
||||
return self.gpt_inference.generate_stream(
|
||||
fake_inputs,
|
||||
bos_token_id=self.start_audio_token,
|
||||
pad_token_id=self.stop_audio_token,
|
||||
eos_token_id=self.stop_audio_token,
|
||||
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
|
||||
do_stream=True,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
|
|
@ -1,658 +0,0 @@
|
|||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
||||
|
||||
class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||
"""Override GPT2LMHeadModel to allow for prefix conditioning."""
|
||||
|
||||
def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
|
||||
super().__init__(config)
|
||||
self.transformer = gpt
|
||||
self.pos_embedding = pos_emb
|
||||
self.embeddings = embeddings
|
||||
self.final_norm = norm
|
||||
self.lm_head = nn.Sequential(norm, linear)
|
||||
self.kv_cache = kv_cache
|
||||
|
||||
def store_prefix_emb(self, prefix_emb):
|
||||
self.cached_prefix_emb = prefix_emb
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None) # usually None
|
||||
if not self.kv_cache:
|
||||
past_key_values = None
|
||||
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values is not None:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values is not None:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
assert self.cached_prefix_emb is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
|
||||
|
||||
# Create embedding
|
||||
prefix_len = self.cached_prefix_emb.shape[1]
|
||||
if input_ids.shape[1] != 1:
|
||||
gen_inputs = input_ids[:, prefix_len:]
|
||||
gen_emb = self.embeddings(gen_inputs)
|
||||
gen_emb = gen_emb + self.pos_embedding(gen_emb)
|
||||
if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
|
||||
prefix_emb = self.cached_prefix_emb.repeat_interleave(
|
||||
gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
|
||||
)
|
||||
else:
|
||||
prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
|
||||
emb = torch.cat([prefix_emb, gen_emb], dim=1)
|
||||
else:
|
||||
emb = self.embeddings(input_ids)
|
||||
emb = emb + self.pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
inputs_embeds=emb,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (lm_logits,) + transformer_outputs[1:]
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=None,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
||||
class LearnedPositionEmbeddings(nn.Module):
|
||||
def __init__(self, seq_len, model_channels, init_std=0.02, relative=False):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(seq_len, model_channels)
|
||||
nn.init.normal_(self.emb.weight, mean=0.0, std=init_std)
|
||||
self.relative = relative
|
||||
|
||||
def forward(self, x):
|
||||
seq_len = x.shape[1]
|
||||
if self.relative:
|
||||
start = torch.randint(seq_len, (1,), device=x.device).item()
|
||||
positions = torch.arange(start, start + seq_len, device=x.device)
|
||||
else:
|
||||
positions = torch.arange(seq_len, device=x.device)
|
||||
return self.emb(positions)
|
||||
|
||||
def get_fixed_embedding(self, ind, dev):
|
||||
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||
|
||||
|
||||
def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing):
|
||||
"""
|
||||
Initializes a GPT-2 model and its position embeddings for a text-to-speech system.
|
||||
|
||||
Args:
|
||||
layers (int): Number of layers in the GPT-2 model.
|
||||
model_channels (int): Dimension of the GPT-2 model.
|
||||
heads (int): Number of heads in the GPT-2 model.
|
||||
max_mel_seq_len (int): Maximum sequence length for the mel spectrogram.
|
||||
max_text_seq_len (int): Maximum sequence length for the text.
|
||||
max_prompt_len (int): Maximum length of the prompt.
|
||||
checkpointing (bool): Whether to use gradient checkpointing.
|
||||
|
||||
Returns:
|
||||
gpt (GPT2Model): GPT-2 model.
|
||||
mel_pos_emb (LearnedPositionEmbeddings): Position embeddings for the mel spectrogram.
|
||||
text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text.
|
||||
"""
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=123,
|
||||
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_embd=model_channels,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing,
|
||||
)
|
||||
gpt = GPT2Model(gpt_config)
|
||||
|
||||
del gpt.wpe
|
||||
del gpt.wte
|
||||
|
||||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels)
|
||||
|
||||
audio_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_mel_seq_len, model_channels)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_channels)
|
||||
)
|
||||
text_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_text_seq_len, model_channels)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_channels)
|
||||
)
|
||||
|
||||
return gpt, audio_pos_emb, text_pos_emb
|
||||
|
||||
|
||||
class XTTSGPTEncoder(nn.Module):
|
||||
"""XTTS GPT Encoder model implementation.
|
||||
Args:
|
||||
start_text_token (int): Index of the start token in the text vocabulary.
|
||||
stop_text_token (int): Index of the stop token in the text vocabulary.
|
||||
n_layers (int): Number of layers in the GPT-2 model.
|
||||
n_model_channels (int): Dimension of the GPT-2 model.
|
||||
n_heads (int): Number of heads in the GPT-2 model.
|
||||
max_text_tokens (int): Maximum number of text tokens.
|
||||
max_audio_tokens (int): Maximum number of audio tokens.
|
||||
max_prompt_tokens (int): Maximum number of prompt tokens.
|
||||
audio_len_compression (int): Compression factor for the audio length.
|
||||
number_text_tokens (int): Number of text tokens.
|
||||
number_audio_codes (int): Number of audio codes.
|
||||
start_mel_token (int): Index of the start token in the mel code vocabulary.
|
||||
stop_mel_token (int): Index of the stop token in the mel code vocabulary.
|
||||
checkpointing (bool): Whether or not to use gradient checkpointing at training.
|
||||
"""
|
||||
|
||||
_inference_flag = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_text_token=261,
|
||||
stop_text_token=0,
|
||||
n_layers=8,
|
||||
n_model_channels=512,
|
||||
n_heads=8,
|
||||
max_text_tokens=120,
|
||||
max_audio_tokens=250,
|
||||
max_prompt_tokens=70,
|
||||
audio_len_compression=1024,
|
||||
number_text_tokens=256,
|
||||
number_audio_codes=8194,
|
||||
start_mel_token=8192,
|
||||
stop_mel_token=8193,
|
||||
checkpointing=True,
|
||||
label_smoothing=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.label_smoothing = label_smoothing
|
||||
self.number_text_tokens = number_text_tokens
|
||||
self.start_text_token = start_text_token
|
||||
self.stop_text_token = stop_text_token
|
||||
self.number_audio_codes = number_audio_codes
|
||||
self.start_mel_token = start_mel_token
|
||||
self.stop_mel_token = stop_mel_token
|
||||
self.start_prompt_token = start_mel_token
|
||||
self.stop_prompt_token = stop_mel_token
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.n_model_channels = n_model_channels
|
||||
self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs
|
||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
|
||||
self.max_prompt_tokens = max_prompt_tokens
|
||||
self.audio_len_compression = audio_len_compression
|
||||
|
||||
# embedding layers
|
||||
self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels)
|
||||
self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
|
||||
self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
|
||||
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels)
|
||||
|
||||
# initialize the GPT-2 model
|
||||
(
|
||||
self.gpt,
|
||||
self.audio_pos_embedding,
|
||||
self.text_pos_embedding,
|
||||
) = init_gpt(
|
||||
n_layers,
|
||||
n_model_channels,
|
||||
n_heads,
|
||||
self.max_audio_tokens,
|
||||
self.max_text_tokens,
|
||||
self.max_prompt_tokens,
|
||||
checkpointing,
|
||||
)
|
||||
|
||||
# output layers
|
||||
self.final_norm = nn.LayerNorm(n_model_channels)
|
||||
self.text_head = nn.Linear(n_model_channels, self.number_text_tokens)
|
||||
self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
|
||||
"gpt": list(self.gpt.parameters()),
|
||||
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
||||
}
|
||||
|
||||
def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
|
||||
self._inference_flag = True
|
||||
seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=self.max_audio_tokens,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=self.n_model_channels,
|
||||
n_layer=self.n_layers,
|
||||
n_head=self.n_heads,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True,
|
||||
)
|
||||
self.inference_model = GPT2InferenceModel(
|
||||
gpt_config,
|
||||
self.gpt,
|
||||
self.audio_pos_embedding,
|
||||
self.audio_embedding,
|
||||
self.final_norm,
|
||||
self.mel_head,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
self.gpt.wte = self.audio_embedding
|
||||
|
||||
def set_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
return inp, tar
|
||||
|
||||
def set_audio_tokens_padding(self, audio_tokens, audio_token_lens):
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
for b in range(len(audio_token_lens)):
|
||||
actual_end = audio_token_lens[b]
|
||||
if actual_end < audio_tokens.shape[-1]:
|
||||
audio_tokens[b, actual_end:] = self.stop_mel_token
|
||||
return audio_tokens
|
||||
|
||||
def get_logits(
|
||||
self,
|
||||
speech_conditioning_inputs,
|
||||
first_inputs,
|
||||
first_head,
|
||||
second_inputs=None,
|
||||
second_head=None,
|
||||
prompt=None,
|
||||
get_attns=False,
|
||||
return_latent=False,
|
||||
attn_mask_text=None,
|
||||
attn_mask_mel=None,
|
||||
):
|
||||
if prompt is not None and speech_conditioning_inputs is not None:
|
||||
offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat(
|
||||
[speech_conditioning_inputs, prompt, first_inputs, second_inputs],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
|
||||
elif speech_conditioning_inputs is not None:
|
||||
offset = speech_conditioning_inputs.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
||||
elif prompt is not None:
|
||||
offset = prompt.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([prompt, first_inputs], dim=1)
|
||||
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||
attn_mask = None
|
||||
if attn_mask_text is not None:
|
||||
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
|
||||
if prompt is not None:
|
||||
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
|
||||
|
||||
gpt_out = self.gpt(
|
||||
inputs_embeds=emb,
|
||||
return_dict=True,
|
||||
output_attentions=get_attns,
|
||||
attention_mask=attn_mask,
|
||||
)
|
||||
|
||||
if get_attns:
|
||||
return gpt_out.attentions
|
||||
|
||||
enc = gpt_out.last_hidden_state[:, offset:]
|
||||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
|
||||
|
||||
first_logits = enc[:, : first_inputs.shape[1]]
|
||||
first_logits = first_head(first_logits)
|
||||
first_logits = first_logits.permute(0, 2, 1)
|
||||
if second_inputs is not None:
|
||||
second_logits = enc[:, -second_inputs.shape[1] :]
|
||||
second_logits = second_head(second_logits)
|
||||
second_logits = second_logits.permute(0, 2, 1)
|
||||
return first_logits, second_logits
|
||||
else:
|
||||
return first_logits
|
||||
|
||||
def get_conditioning(self, speech_conditioning_input):
|
||||
speech_conditioning_input = (
|
||||
speech_conditioning_input.unsqueeze(1)
|
||||
if len(speech_conditioning_input.shape) == 3
|
||||
else speech_conditioning_input
|
||||
)
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1)
|
||||
return conds
|
||||
|
||||
def get_prompts(self, prompt_codes):
|
||||
prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token)
|
||||
prompt = F.pad(prompt_codes, (0, 1), value=self.stop_prompt_token)
|
||||
return prompt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_inputs,
|
||||
text_lengths,
|
||||
audio_codes,
|
||||
wav_lengths,
|
||||
prompt_codes,
|
||||
return_attentions=False,
|
||||
return_latent=False,
|
||||
):
|
||||
max_text_len = text_lengths.max()
|
||||
|
||||
# Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values
|
||||
# Like [..., 186, 45, 45, 83] where actually it should end with 186.
|
||||
# We take last 3 codes to prevent abrupt ending of the audio.
|
||||
# TODO: This is might need some testing.
|
||||
mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3
|
||||
|
||||
# If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
|
||||
max_mel_len = mel_lengths.max()
|
||||
|
||||
if max_mel_len > audio_codes.shape[-1]:
|
||||
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
|
||||
|
||||
# silence aware lengths, skip the silence tokens at the end of the mel codes.
|
||||
silence = True
|
||||
for idx, l in enumerate(mel_lengths):
|
||||
length = l.item()
|
||||
while silence:
|
||||
if audio_codes[idx, length - 1] != 83:
|
||||
break
|
||||
length -= 1
|
||||
mel_lengths[idx] = length
|
||||
|
||||
# Lovely assertions
|
||||
assert (
|
||||
max_mel_len <= audio_codes.shape[-1]
|
||||
), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
|
||||
assert (
|
||||
max_text_len <= text_inputs.shape[-1]
|
||||
), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
|
||||
|
||||
# Append stop token to text inputs
|
||||
text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
|
||||
|
||||
# Append silence token to mel codes
|
||||
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
|
||||
|
||||
# Pad mel codes with STOP_MEL_TOKEN
|
||||
audio_codes = self.set_mel_padding(audio_codes, mel_lengths)
|
||||
|
||||
# Compute speech conditioning input
|
||||
conds = None
|
||||
if speech_conditioning_input is not None:
|
||||
if not return_latent:
|
||||
# Compute speech conditioning input
|
||||
speech_conditioning_input = (
|
||||
speech_conditioning_input.unsqueeze(1)
|
||||
if len(speech_conditioning_input.shape) == 3
|
||||
else speech_conditioning_input
|
||||
)
|
||||
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
else:
|
||||
# already computed
|
||||
conds = speech_conditioning_input.unsqueeze(1)
|
||||
|
||||
# Build input and target tensors
|
||||
# Prepend start token to inputs and append stop token to targets
|
||||
text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token)
|
||||
|
||||
# Set attn_mask
|
||||
attn_mask_text = None
|
||||
attn_mask_mel = None
|
||||
if not return_latent:
|
||||
attn_mask_text = torch.ones(
|
||||
text_inputs.shape[0],
|
||||
text_inputs.shape[1],
|
||||
dtype=torch.bool,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
attn_mask_mel = torch.ones(
|
||||
audio_codes.shape[0],
|
||||
audio_codes.shape[1],
|
||||
dtype=torch.bool,
|
||||
device=audio_codes.device,
|
||||
)
|
||||
|
||||
for idx, l in enumerate(text_lengths):
|
||||
attn_mask_text[idx, l + 1 :] = 0.0
|
||||
|
||||
for idx, l in enumerate(mel_lengths):
|
||||
attn_mask_mel[idx, l + 1 :] = 0.0
|
||||
|
||||
# Compute text embeddings + positional embeddings
|
||||
# print(" > text input latent:", text_inputs)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
# Compute mel embeddings + positional embeddings
|
||||
audio_emb = self.audio_embedding(audio_codes) + self.audio_embedding(audio_codes)
|
||||
|
||||
# Compute prompt embeddings + positional embeddings
|
||||
prompt = self.get_prompts(prompt_codes)
|
||||
|
||||
# prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
|
||||
prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt)
|
||||
|
||||
# dropout prompt embeddings
|
||||
prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training)
|
||||
|
||||
# Get logits
|
||||
sub = -4 # don't ask me why 😄
|
||||
if self.training:
|
||||
sub = -1
|
||||
_, audio_logits = self.get_logits(
|
||||
conds,
|
||||
text_emb,
|
||||
self.text_head,
|
||||
audio_emb,
|
||||
self.mel_head,
|
||||
prompt=prompt_emb,
|
||||
get_attns=return_attentions,
|
||||
return_latent=return_latent,
|
||||
attn_mask_text=attn_mask_text,
|
||||
attn_mask_mel=attn_mask_mel,
|
||||
)
|
||||
return audio_logits[:, :sub] # sub to prevent bla.
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
speech_conditioning_latent,
|
||||
text_inputs,
|
||||
input_tokens=None,
|
||||
prompt_codes=None,
|
||||
pad_input_text=False,
|
||||
):
|
||||
"""Compute all the embeddings needed for inference."""
|
||||
if pad_input_text and text_inputs.shape[1] < 250:
|
||||
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
|
||||
else:
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
|
||||
|
||||
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
print(" > Text inputs:", text_inputs)
|
||||
if prompt_codes is not None:
|
||||
prompt_codes = self.get_prompts(prompt_codes)
|
||||
# prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
|
||||
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
|
||||
|
||||
print(" > Prompt inputs:", prompt_codes)
|
||||
print(" > Prompt inputs shape:", prompt_codes.shape)
|
||||
emb = torch.cat([prompt_emb, emb], dim=1)
|
||||
|
||||
if speech_conditioning_latent is not None:
|
||||
conds = speech_conditioning_latent.unsqueeze(1)
|
||||
emb = torch.cat([conds, emb], dim=1)
|
||||
|
||||
self.inference_model.store_prefix_emb(emb)
|
||||
|
||||
fake_inputs = torch.full(
|
||||
(
|
||||
emb.shape[0],
|
||||
emb.shape[1] + 1, # +1 for the start_mel_token
|
||||
),
|
||||
fill_value=1,
|
||||
dtype=torch.long,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
if input_tokens is not None:
|
||||
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
return fake_inputs
|
||||
|
||||
def inference(
|
||||
self,
|
||||
text_inputs,
|
||||
input_tokens=None,
|
||||
prompt_codes=None,
|
||||
pad_input_text=False,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
if pad_input_text and text_inputs.shape[1] < 250:
|
||||
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
|
||||
else:
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
|
||||
|
||||
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
if prompt_codes is not None:
|
||||
prompt_codes = self.get_prompts(prompt_codes)
|
||||
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
|
||||
emb = torch.cat([prompt_emb, emb], dim=1)
|
||||
|
||||
self.inference_model.store_prefix_emb(emb)
|
||||
|
||||
fake_inputs = torch.full(
|
||||
(
|
||||
emb.shape[0],
|
||||
emb.shape[1] + 1, # +1 for the start_mel_token
|
||||
),
|
||||
fill_value=1,
|
||||
dtype=torch.long,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
if input_tokens is not None:
|
||||
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
|
||||
gen = self.inference_model.generate(
|
||||
fake_inputs,
|
||||
bos_token_id=self.start_mel_token,
|
||||
pad_token_id=self.stop_mel_token,
|
||||
eos_token_id=self.stop_mel_token,
|
||||
max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||
return gen.sequences[:, fake_inputs.shape[1] :], gen
|
||||
return gen[:, fake_inputs.shape[1] :]
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,742 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
import torchaudio
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def get_padding(k, d):
|
||||
return int((k * d - d) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|
||||
|--------------------------------------------------------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input tensor.
|
||||
Returns:
|
||||
Tensor: output tensor.
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
"""
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|
||||
|---------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super().__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class HifiganGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
resblock_type,
|
||||
resblock_dilation_sizes,
|
||||
resblock_kernel_sizes,
|
||||
upsample_kernel_sizes,
|
||||
upsample_initial_channel,
|
||||
upsample_factors,
|
||||
inference_padding=5,
|
||||
cond_channels=0,
|
||||
conv_pre_weight_norm=True,
|
||||
conv_post_weight_norm=True,
|
||||
conv_post_bias=True,
|
||||
cond_in_each_up_layer=False,
|
||||
):
|
||||
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
|
||||
|
||||
Network:
|
||||
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
|
||||
.. -> zI ---|
|
||||
resblockN_kNx1 -> zN ---'
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input tensor channels.
|
||||
out_channels (int): number of output tensor channels.
|
||||
resblock_type (str): type of the `ResBlock`. '1' or '2'.
|
||||
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
|
||||
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
|
||||
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
|
||||
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
|
||||
for each consecutive upsampling layer.
|
||||
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
|
||||
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inference_padding = inference_padding
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_factors)
|
||||
self.cond_in_each_up_layer = cond_in_each_up_layer
|
||||
|
||||
# initial upsampling layers
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
)
|
||||
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
|
||||
# upsampling layers
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
# MRF blocks
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||
):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
# post convolution layer
|
||||
self.conv_post = weight_norm(
|
||||
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
|
||||
)
|
||||
if cond_channels > 0:
|
||||
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
||||
|
||||
if not conv_pre_weight_norm:
|
||||
remove_weight_norm(self.conv_pre)
|
||||
|
||||
if not conv_post_weight_norm:
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
if self.cond_in_each_up_layer:
|
||||
self.conds = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
self.conds.append(nn.Conv1d(cond_channels, ch, 1))
|
||||
|
||||
def forward(self, x, g=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): feature input tensor.
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
o = self.conv_pre(x)
|
||||
if hasattr(self, "cond_layer"):
|
||||
o = o + self.cond_layer(g)
|
||||
for i in range(self.num_upsamples):
|
||||
o = F.leaky_relu(o, LRELU_SLOPE)
|
||||
o = self.ups[i](o)
|
||||
|
||||
if self.cond_in_each_up_layer:
|
||||
o = o + self.conds[i](g)
|
||||
|
||||
z_sum = None
|
||||
for j in range(self.num_kernels):
|
||||
if z_sum is None:
|
||||
z_sum = self.resblocks[i * self.num_kernels + j](o)
|
||||
else:
|
||||
z_sum += self.resblocks[i * self.num_kernels + j](o)
|
||||
o = z_sum / self.num_kernels
|
||||
o = F.leaky_relu(o)
|
||||
o = self.conv_post(o)
|
||||
o = torch.tanh(o)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(
|
||||
c, (self.inference_padding, self.inference_padding), "replicate"
|
||||
)
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super(SELayer, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
y = self.avg_pool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y
|
||||
|
||||
|
||||
class SEBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
|
||||
super(SEBasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se = SELayer(planes, reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.relu(out)
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.se(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
def set_init_dict(model_dict, checkpoint_state, c):
|
||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||
for k, v in checkpoint_state.items():
|
||||
if k not in model_dict:
|
||||
print(" | > Layer missing in the model definition: {}".format(k))
|
||||
# 1. filter out unnecessary keys
|
||||
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
|
||||
# 2. filter out different size layers
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
|
||||
# 3. skip reinit layers
|
||||
if c.has("reinit_layers") and c.reinit_layers is not None:
|
||||
for reinit_layer_name in c.reinit_layers:
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
|
||||
# 4. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrained_dict)
|
||||
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
|
||||
return model_dict
|
||||
|
||||
|
||||
class PreEmphasis(nn.Module):
|
||||
def __init__(self, coefficient=0.97):
|
||||
super().__init__()
|
||||
self.coefficient = coefficient
|
||||
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
def forward(self, x):
|
||||
assert len(x.size()) == 2
|
||||
|
||||
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
|
||||
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
|
||||
|
||||
|
||||
|
||||
class ResNetSpeakerEncoder(nn.Module):
|
||||
"""This is copied from 🐸TTS to remove it from the dependencies.
|
||||
"""
|
||||
|
||||
# pylint: disable=W0102
|
||||
def __init__(
|
||||
self,
|
||||
input_dim=64,
|
||||
proj_dim=512,
|
||||
layers=[3, 4, 6, 3],
|
||||
num_filters=[32, 64, 128, 256],
|
||||
encoder_type="ASP",
|
||||
log_input=False,
|
||||
use_torch_spec=False,
|
||||
audio_config=None,
|
||||
):
|
||||
super(ResNetSpeakerEncoder, self).__init__()
|
||||
|
||||
self.encoder_type = encoder_type
|
||||
self.input_dim = input_dim
|
||||
self.log_input = log_input
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.audio_config = audio_config
|
||||
self.proj_dim = proj_dim
|
||||
|
||||
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.bn1 = nn.BatchNorm2d(num_filters[0])
|
||||
|
||||
self.inplanes = num_filters[0]
|
||||
self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
|
||||
self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
|
||||
self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
|
||||
self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
|
||||
|
||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||
|
||||
if self.use_torch_spec:
|
||||
self.torch_spec = torch.nn.Sequential(
|
||||
PreEmphasis(audio_config["preemphasis"]),
|
||||
torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=audio_config["sample_rate"],
|
||||
n_fft=audio_config["fft_size"],
|
||||
win_length=audio_config["win_length"],
|
||||
hop_length=audio_config["hop_length"],
|
||||
window_fn=torch.hamming_window,
|
||||
n_mels=audio_config["num_mels"],
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
self.torch_spec = None
|
||||
|
||||
outmap_size = int(self.input_dim / 8)
|
||||
|
||||
self.attention = nn.Sequential(
|
||||
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
|
||||
nn.Softmax(dim=2),
|
||||
)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
out_dim = num_filters[3] * outmap_size
|
||||
elif self.encoder_type == "ASP":
|
||||
out_dim = num_filters[3] * outmap_size * 2
|
||||
else:
|
||||
raise ValueError("Undefined encoder")
|
||||
|
||||
self.fc = nn.Linear(out_dim, proj_dim)
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def create_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
# pylint: disable=R0201
|
||||
def new_parameter(self, *size):
|
||||
out = nn.Parameter(torch.FloatTensor(*size))
|
||||
nn.init.xavier_normal_(out)
|
||||
return out
|
||||
|
||||
def forward(self, x, l2_norm=False):
|
||||
"""Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
|
||||
to compute the spectrogram on-the-fly.
|
||||
l2_norm (bool): Whether to L2-normalize the outputs.
|
||||
|
||||
Shapes:
|
||||
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
|
||||
"""
|
||||
x.squeeze_(1)
|
||||
# if you torch spec compute it otherwise use the mel spec computed by the AP
|
||||
if self.use_torch_spec:
|
||||
x = self.torch_spec(x)
|
||||
|
||||
if self.log_input:
|
||||
x = (x + 1e-6).log()
|
||||
x = self.instancenorm(x).unsqueeze(1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.bn1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = x.reshape(x.size()[0], -1, x.size()[-1])
|
||||
|
||||
w = self.attention(x)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
x = torch.sum(x * w, dim=2)
|
||||
elif self.encoder_type == "ASP":
|
||||
mu = torch.sum(x * w, dim=2)
|
||||
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
|
||||
x = torch.cat((mu, sg), 1)
|
||||
|
||||
x = x.view(x.size()[0], -1)
|
||||
x = self.fc(x)
|
||||
|
||||
if l2_norm:
|
||||
x = torch.nn.functional.normalize(x, p=2, dim=1)
|
||||
return x
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
checkpoint_path: str,
|
||||
eval: bool = False,
|
||||
use_cuda: bool = False,
|
||||
criterion=None,
|
||||
cache=False,
|
||||
):
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
try:
|
||||
self.load_state_dict(state["model"])
|
||||
print(" > Model fully restored. ")
|
||||
except (KeyError, RuntimeError) as error:
|
||||
# If eval raise the error
|
||||
if eval:
|
||||
raise error
|
||||
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = self.state_dict()
|
||||
model_dict = set_init_dict(model_dict, state["model"])
|
||||
self.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
# load the criterion for restore_path
|
||||
if criterion is not None and "criterion" in state:
|
||||
try:
|
||||
criterion.load_state_dict(state["criterion"])
|
||||
except (KeyError, RuntimeError) as error:
|
||||
print(" > Criterion load ignored because of:", error)
|
||||
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
if criterion is not None:
|
||||
criterion = criterion.cuda()
|
||||
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
if not eval:
|
||||
return criterion, state["step"]
|
||||
return criterion
|
||||
|
||||
class HifiDecoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_sample_rate=22050,
|
||||
output_sample_rate=24000,
|
||||
output_hop_length=256,
|
||||
ar_mel_length_compression=1024,
|
||||
decoder_input_dim=1024,
|
||||
resblock_type_decoder="1",
|
||||
resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
resblock_kernel_sizes_decoder=[3, 7, 11],
|
||||
upsample_rates_decoder=[8, 8, 2, 2],
|
||||
upsample_initial_channel_decoder=512,
|
||||
upsample_kernel_sizes_decoder=[16, 16, 4, 4],
|
||||
d_vector_dim=512,
|
||||
cond_d_vector_in_each_upsampling_layer=True,
|
||||
speaker_encoder_audio_config={
|
||||
"fft_size": 512,
|
||||
"win_length": 400,
|
||||
"hop_length": 160,
|
||||
"sample_rate": 16000,
|
||||
"preemphasis": 0.97,
|
||||
"num_mels": 64,
|
||||
},
|
||||
):
|
||||
super().__init__()
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_hop_length = output_hop_length
|
||||
self.ar_mel_length_compression = ar_mel_length_compression
|
||||
self.speaker_encoder_audio_config = speaker_encoder_audio_config
|
||||
self.waveform_decoder = HifiganGenerator(
|
||||
decoder_input_dim,
|
||||
1,
|
||||
resblock_type_decoder,
|
||||
resblock_dilation_sizes_decoder,
|
||||
resblock_kernel_sizes_decoder,
|
||||
upsample_kernel_sizes_decoder,
|
||||
upsample_initial_channel_decoder,
|
||||
upsample_rates_decoder,
|
||||
inference_padding=0,
|
||||
cond_channels=d_vector_dim,
|
||||
conv_pre_weight_norm=False,
|
||||
conv_post_weight_norm=False,
|
||||
conv_post_bias=False,
|
||||
cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
self.speaker_encoder = ResNetSpeakerEncoder(
|
||||
input_dim=64,
|
||||
proj_dim=512,
|
||||
log_input=True,
|
||||
use_torch_spec=True,
|
||||
audio_config=speaker_encoder_audio_config,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, latents, g=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): feature input tensor (GPT latent).
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
|
||||
z = torch.nn.functional.interpolate(
|
||||
latents.transpose(1, 2),
|
||||
scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
|
||||
mode="linear",
|
||||
).squeeze(1)
|
||||
# upsample to the right sr
|
||||
if self.output_sample_rate != self.input_sample_rate:
|
||||
z = torch.nn.functional.interpolate(
|
||||
z,
|
||||
scale_factor=[self.output_sample_rate / self.input_sample_rate],
|
||||
mode="linear",
|
||||
).squeeze(0)
|
||||
o = self.waveform_decoder(z, g=g)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c, g):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): feature input tensor (GPT latent).
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
return self.forward(c, g=g)
|
||||
|
||||
def load_checkpoint(
|
||||
self, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
# remove unused keys
|
||||
state = state["model"]
|
||||
states_keys = list(state.keys())
|
||||
for key in states_keys:
|
||||
if "waveform_decoder." not in key and "speaker_encoder." not in key:
|
||||
del state[key]
|
||||
|
||||
self.load_state_dict(state)
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.waveform_decoder.remove_weight_norm()
|
File diff suppressed because it is too large
Load Diff
|
@ -1,243 +1,468 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
|
||||
import inflect
|
||||
import pandas as pd
|
||||
import pypinyin
|
||||
import torch
|
||||
from num2words import num2words
|
||||
from tokenizers import Tokenizer
|
||||
from unidecode import unidecode
|
||||
|
||||
from TTS.tts.utils.text.cleaners import english_cleaners
|
||||
import pypinyin
|
||||
from num2words import num2words
|
||||
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
||||
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
||||
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
||||
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
||||
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
||||
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(",", "")
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace(".", " point ")
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
match = m.group(1)
|
||||
parts = match.split(".")
|
||||
if len(parts) > 2:
|
||||
return match + " dollars" # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||
cent_unit = "cent" if cents == 1 else "cents"
|
||||
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||
return "%s %s" % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = "cent" if cents == 1 else "cents"
|
||||
return "%s %s" % (cents, cent_unit)
|
||||
else:
|
||||
return "zero dollars"
|
||||
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return "two thousand"
|
||||
elif num > 2000 and num < 2010:
|
||||
return "two thousand " + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + " hundred"
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword="")
|
||||
|
||||
|
||||
def normalize_numbers(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r"\1 pounds", text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
||||
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
]
|
||||
]
|
||||
_abbreviations = {
|
||||
"en": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
]
|
||||
],
|
||||
"es": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("sra", "señora"),
|
||||
("sr", "señor"),
|
||||
("dr", "doctor"),
|
||||
("dra", "doctora"),
|
||||
("st", "santo"),
|
||||
("co", "compañía"),
|
||||
("jr", "junior"),
|
||||
("ltd", "limitada"),
|
||||
]
|
||||
],
|
||||
"fr": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mme", "madame"),
|
||||
("mr", "monsieur"),
|
||||
("dr", "docteur"),
|
||||
("st", "saint"),
|
||||
("co", "compagnie"),
|
||||
("jr", "junior"),
|
||||
("ltd", "limitée"),
|
||||
]
|
||||
],
|
||||
"de": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("fr", "frau"),
|
||||
("dr", "doktor"),
|
||||
("st", "sankt"),
|
||||
("co", "firma"),
|
||||
("jr", "junior"),
|
||||
]
|
||||
],
|
||||
"pt": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("sra", "senhora"),
|
||||
("sr", "senhor"),
|
||||
("dr", "doutor"),
|
||||
("dra", "doutora"),
|
||||
("st", "santo"),
|
||||
("co", "companhia"),
|
||||
("jr", "júnior"),
|
||||
("ltd", "limitada"),
|
||||
]
|
||||
],
|
||||
"it": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
#("sig.ra", "signora"),
|
||||
("sig", "signore"),
|
||||
("dr", "dottore"),
|
||||
("st", "santo"),
|
||||
("co", "compagnia"),
|
||||
("jr", "junior"),
|
||||
("ltd", "limitata"),
|
||||
]
|
||||
],
|
||||
"pl": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("p", "pani"),
|
||||
("m", "pan"),
|
||||
("dr", "doktor"),
|
||||
("sw", "święty"),
|
||||
("jr", "junior"),
|
||||
]
|
||||
],
|
||||
"ar": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
# There are not many common abbreviations in Arabic as in English.
|
||||
]
|
||||
],
|
||||
"zh-cn": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
|
||||
]
|
||||
],
|
||||
"cs": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("dr", "doktor"), # doctor
|
||||
("ing", "inženýr"), # engineer
|
||||
("p", "pan"), # Could also map to pani for woman but no easy way to do it
|
||||
# Other abbreviations would be specialized and not as common.
|
||||
]
|
||||
],
|
||||
"ru": [
|
||||
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("г-жа", "госпожа"), # Mrs.
|
||||
("г-н", "господин"), # Mr.
|
||||
("д-р", "доктор"), # doctor
|
||||
# Other abbreviations are less common or specialized.
|
||||
]
|
||||
],
|
||||
"nl": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("dhr", "de heer"), # Mr.
|
||||
("mevr", "mevrouw"), # Mrs.
|
||||
("dr", "dokter"), # doctor
|
||||
("jhr", "jonkheer"), # young lord or nobleman
|
||||
# Dutch uses more abbreviations, but these are the most common ones.
|
||||
]
|
||||
],
|
||||
"tr": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("b", "bay"), # Mr.
|
||||
("byk", "büyük"), # büyük
|
||||
("dr", "doktor"), # doctor
|
||||
# Add other Turkish abbreviations here if needed.
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
def expand_abbreviations_multilingual(text, lang='en'):
|
||||
for regex, replacement in _abbreviations[lang]:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
_symbols_multilingual = {
|
||||
'en': [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " and "),
|
||||
("@", " at "),
|
||||
("%", " percent "),
|
||||
("#", " hash "),
|
||||
("$", " dollar "),
|
||||
("£", " pound "),
|
||||
("°", " degree ")
|
||||
]
|
||||
],
|
||||
'es': [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " y "),
|
||||
("@", " arroba "),
|
||||
("%", " por ciento "),
|
||||
("#", " numeral "),
|
||||
("$", " dolar "),
|
||||
("£", " libra "),
|
||||
("°", " grados ")
|
||||
]
|
||||
],
|
||||
'fr': [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " et "),
|
||||
("@", " arobase "),
|
||||
("%", " pour cent "),
|
||||
("#", " dièse "),
|
||||
("$", " dollar "),
|
||||
("£", " livre "),
|
||||
("°", " degrés ")
|
||||
]
|
||||
],
|
||||
'de': [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " und "),
|
||||
("@", " at "),
|
||||
("%", " prozent "),
|
||||
("#", " raute "),
|
||||
("$", " dollar "),
|
||||
("£", " pfund "),
|
||||
("°", " grad ")
|
||||
]
|
||||
],
|
||||
'pt': [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " e "),
|
||||
("@", " arroba "),
|
||||
("%", " por cento "),
|
||||
("#", " cardinal "),
|
||||
("$", " dólar "),
|
||||
("£", " libra "),
|
||||
("°", " graus ")
|
||||
]
|
||||
],
|
||||
'it': [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " e "),
|
||||
("@", " chiocciola "),
|
||||
("%", " per cento "),
|
||||
("#", " cancelletto "),
|
||||
("$", " dollaro "),
|
||||
("£", " sterlina "),
|
||||
("°", " gradi ")
|
||||
]
|
||||
],
|
||||
'pl': [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " i "),
|
||||
("@", " małpa "),
|
||||
("%", " procent "),
|
||||
("#", " krzyżyk "),
|
||||
("$", " dolar "),
|
||||
("£", " funt "),
|
||||
("°", " stopnie ")
|
||||
]
|
||||
],
|
||||
"ar": [
|
||||
# Arabic
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " و "),
|
||||
("@", " على "),
|
||||
("%", " في المئة "),
|
||||
("#", " رقم "),
|
||||
("$", " دولار "),
|
||||
("£", " جنيه "),
|
||||
("°", " درجة ")
|
||||
]
|
||||
],
|
||||
"zh-cn": [
|
||||
# Chinese
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " 和 "),
|
||||
("@", " 在 "),
|
||||
("%", " 百分之 "),
|
||||
("#", " 号 "),
|
||||
("$", " 美元 "),
|
||||
("£", " 英镑 "),
|
||||
("°", " 度 ")
|
||||
]
|
||||
],
|
||||
"cs": [
|
||||
# Czech
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " a "),
|
||||
("@", " na "),
|
||||
("%", " procento "),
|
||||
("#", " křížek "),
|
||||
("$", " dolar "),
|
||||
("£", " libra "),
|
||||
("°", " stupně ")
|
||||
]
|
||||
],
|
||||
"ru": [
|
||||
# Russian
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " и "),
|
||||
("@", " собака "),
|
||||
("%", " процентов "),
|
||||
("#", " номер "),
|
||||
("$", " доллар "),
|
||||
("£", " фунт "),
|
||||
("°", " градус ")
|
||||
]
|
||||
],
|
||||
"nl": [
|
||||
# Dutch
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " en "),
|
||||
("@", " bij "),
|
||||
("%", " procent "),
|
||||
("#", " hekje "),
|
||||
("$", " dollar "),
|
||||
("£", " pond "),
|
||||
("°", " graden ")
|
||||
]
|
||||
],
|
||||
"tr": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " ve "),
|
||||
("@", " at "),
|
||||
("%", " yüzde "),
|
||||
("#", " diyez "),
|
||||
("$", " dolar "),
|
||||
("£", " sterlin "),
|
||||
("°", " derece ")
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
def expand_symbols_multilingual(text, lang='en'):
|
||||
for regex, replacement in _symbols_multilingual[lang]:
|
||||
text = re.sub(regex, replacement, text)
|
||||
text = text.replace(' ', ' ') # Ensure there are no double spaces
|
||||
return text.strip()
|
||||
|
||||
|
||||
_ordinal_re = {
|
||||
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
|
||||
"es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
|
||||
"fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
|
||||
"de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
|
||||
"pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
|
||||
"it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
|
||||
"pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
|
||||
"ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
|
||||
"cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
|
||||
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
|
||||
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
|
||||
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
|
||||
}
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
_currency_re = {
|
||||
'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
||||
'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
||||
'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))")
|
||||
}
|
||||
|
||||
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
||||
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
||||
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
||||
|
||||
def _remove_commas(m):
|
||||
text = m.group(0)
|
||||
if "," in text:
|
||||
text = text.replace(",", "")
|
||||
return text
|
||||
|
||||
def _remove_dots(m):
|
||||
text = m.group(0)
|
||||
if "." in text:
|
||||
text = text.replace(".", "")
|
||||
return text
|
||||
|
||||
def _expand_decimal_point(m, lang='en'):
|
||||
amount = m.group(1).replace(",", ".")
|
||||
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
|
||||
|
||||
def _expand_currency(m, lang='en', currency='USD'):
|
||||
amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", "."))))
|
||||
full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz")
|
||||
|
||||
and_equivalents = {
|
||||
"en": ", ",
|
||||
"es": " con ",
|
||||
"fr": " et ",
|
||||
"de": " und ",
|
||||
"pt": " e ",
|
||||
"it": " e ",
|
||||
"pl": ", ",
|
||||
"cs": ", ",
|
||||
"ru": ", ",
|
||||
"nl": ", ",
|
||||
"ar": ", ",
|
||||
"tr": ", ",
|
||||
}
|
||||
|
||||
if amount.is_integer():
|
||||
last_and = full_amount.rfind(and_equivalents[lang])
|
||||
if last_and != -1:
|
||||
full_amount = full_amount[:last_and]
|
||||
|
||||
return full_amount
|
||||
|
||||
def _expand_ordinal(m, lang='en'):
|
||||
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
|
||||
|
||||
def _expand_number(m, lang='en'):
|
||||
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
|
||||
|
||||
def expand_numbers_multilingual(text, lang='en'):
|
||||
if lang == "zh-cn":
|
||||
text = zh_num2words()(text)
|
||||
else:
|
||||
if lang in ["en", "ru"]:
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
else:
|
||||
text = re.sub(_dot_number_re, _remove_dots, text)
|
||||
try:
|
||||
text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text)
|
||||
text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text)
|
||||
text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text)
|
||||
except:
|
||||
pass
|
||||
if lang != "tr":
|
||||
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
|
||||
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
|
||||
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
||||
return text
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, " ", text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', "")
|
||||
return text
|
||||
|
||||
|
||||
def expand_numbers_multilang(text, lang):
|
||||
# TODO: Handle text more carefully. Currently, it just converts numbers without any context.
|
||||
# Find all numbers in the input string
|
||||
numbers = re.findall(r"\d+", text)
|
||||
|
||||
# Transliterate the numbers to text
|
||||
for num in numbers:
|
||||
transliterated_num = "".join(num2words(num, lang=lang))
|
||||
text = text.replace(num, transliterated_num, 1)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
"""Pipeline for non-English text that transliterates to ASCII."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def multilingual_cleaners(text, lang):
|
||||
text = lowercase(text)
|
||||
text = expand_numbers_multilang(text, lang)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', "")
|
||||
if lang == "tr":
|
||||
text = text.replace('"', '')
|
||||
if lang=="tr":
|
||||
text = text.replace("İ", "i")
|
||||
text = text.replace("Ö", "ö")
|
||||
text = text.replace("Ü", "ü")
|
||||
return text
|
||||
|
||||
|
||||
def english_cleaners(text):
|
||||
"""Pipeline for English text, including number and abbreviation expansion."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = expand_numbers_multilingual(text, lang)
|
||||
text = expand_abbreviations_multilingual(text, lang)
|
||||
text = expand_symbols_multilingual(text, lang=lang)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', "")
|
||||
return text
|
||||
|
||||
|
||||
def remove_extraneous_punctuation(word):
|
||||
replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "—": "-", "—": "-", "`": "'", "ʼ": "'"}
|
||||
replace = re.compile(
|
||||
"|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL
|
||||
)
|
||||
word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
|
||||
|
||||
# TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
|
||||
extraneous = re.compile(r"^[@#%_=\$\^&\*\+\\]$")
|
||||
word = extraneous.sub("", word)
|
||||
return word
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, " ", text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
def chinese_transliterate(text):
|
||||
return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)])
|
||||
|
||||
def arabic_cleaners(text):
|
||||
def japanese_cleaners(text, katsu):
|
||||
text = katsu.romaji(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def chinese_cleaners(text):
|
||||
text = lowercase(text)
|
||||
text = "".join(
|
||||
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file=None, preprocess=None):
|
||||
self.tokenizer = None
|
||||
self.katsu = None
|
||||
|
||||
if vocab_file is not None:
|
||||
with open(vocab_file, "r", encoding="utf-8") as f:
|
||||
|
@ -253,21 +478,20 @@ class VoiceBpeTokenizer:
|
|||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
|
||||
def preprocess_text(self, txt, lang):
|
||||
if lang == "ja":
|
||||
import pykakasi
|
||||
|
||||
kks = pykakasi.kakasi()
|
||||
results = kks.convert(txt)
|
||||
txt = " ".join([result["kana"] for result in results])
|
||||
txt = basic_cleaners(txt)
|
||||
elif lang == "en":
|
||||
txt = english_cleaners(txt)
|
||||
elif lang == "ar":
|
||||
txt = arabic_cleaners(txt)
|
||||
elif lang == "zh-cn":
|
||||
txt = chinese_cleaners(txt)
|
||||
else:
|
||||
if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]:
|
||||
txt = multilingual_cleaners(txt, lang)
|
||||
if lang == "zh-cn":
|
||||
txt = chinese_transliterate(txt)
|
||||
elif lang == "ja":
|
||||
assert txt[:4] == "[ja]", "Japanese speech should start with the [ja] token."
|
||||
txt = txt[4:]
|
||||
if self.katsu is None:
|
||||
import cutlet
|
||||
self.katsu = cutlet.Cutlet()
|
||||
txt = japanese_cleaners(txt, self.katsu)
|
||||
txt = "[ja]" + txt
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return txt
|
||||
|
||||
def encode(self, txt, lang):
|
||||
|
@ -284,3 +508,9 @@ class VoiceBpeTokenizer:
|
|||
txt = txt.replace("[STOP]", "")
|
||||
txt = txt.replace("[UNK]", "")
|
||||
return txt
|
||||
|
||||
def __len__(self):
|
||||
return self.tokenizer.get_vocab_size()
|
||||
|
||||
def get_number_tokens(self):
|
||||
return max(self.tokenizer.get_vocab().values()) + 1
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -726,8 +726,8 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
def pitch_std(self):
|
||||
return self.acoustic_model.pitch_std
|
||||
|
||||
@pitch_mean.setter
|
||||
def pitch_std(self, value): # pylint: disable=function-redefined
|
||||
@pitch_std.setter
|
||||
def pitch_std(self, value):
|
||||
self.acoustic_model.pitch_std = value
|
||||
|
||||
@property
|
||||
|
@ -1518,10 +1518,6 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
|
||||
return [scheduler_D, scheduler_G]
|
||||
|
||||
def on_train_step_start(self, trainer):
|
||||
"""Schedule binary loss weight."""
|
||||
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
||||
|
||||
def on_epoch_end(self, trainer): # pylint: disable=unused-argument
|
||||
# stop updating mean and var
|
||||
# TODO: do the same for F0
|
||||
|
@ -1578,6 +1574,7 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
Args:
|
||||
trainer (Trainer): Trainer object.
|
||||
"""
|
||||
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
||||
self.train_disc = ( # pylint: disable=attribute-defined-outside-init
|
||||
trainer.total_steps_done >= self.config.steps_to_start_discriminator
|
||||
)
|
||||
|
|
|
@ -396,6 +396,7 @@ class ForwardTTS(BaseTTS):
|
|||
- g: :math:`(B, C)`
|
||||
"""
|
||||
if hasattr(self, "emb_g"):
|
||||
g = g.type(torch.LongTensor)
|
||||
g = self.emb_g(g) # [B, C, 1]
|
||||
if g is not None:
|
||||
g = g.unsqueeze(-1)
|
||||
|
@ -683,9 +684,10 @@ class ForwardTTS(BaseTTS):
|
|||
# encoder pass
|
||||
o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
|
||||
# duration predictor pass
|
||||
o_dr_log = self.duration_predictor(o_en, x_mask)
|
||||
o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask)
|
||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||
y_lengths = o_dr.sum(1)
|
||||
|
||||
# pitch predictor pass
|
||||
o_pitch = None
|
||||
if self.args.use_pitch:
|
||||
|
|
|
@ -13,9 +13,12 @@ from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedu
|
|||
from TTS.tts.layers.xtts.gpt import GPT
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
|
||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
init_stream_support()
|
||||
|
||||
def load_audio(audiopath, sr=22050):
|
||||
"""
|
||||
|
@ -195,13 +198,12 @@ class XttsArgs(Coqpit):
|
|||
Args:
|
||||
gpt_batch_size (int): The size of the auto-regressive batch.
|
||||
enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
|
||||
lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
|
||||
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
|
||||
gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
|
||||
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
|
||||
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
|
||||
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
|
||||
vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
|
||||
use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.
|
||||
|
||||
For GPT model:
|
||||
ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
|
||||
|
@ -231,12 +233,13 @@ class XttsArgs(Coqpit):
|
|||
|
||||
gpt_batch_size: int = 1
|
||||
enable_redaction: bool = False
|
||||
lazy_load: bool = True
|
||||
kv_cache: bool = True
|
||||
gpt_checkpoint: str = None
|
||||
clvp_checkpoint: str = None
|
||||
decoder_checkpoint: str = None
|
||||
num_chars: int = 255
|
||||
use_hifigan: bool = True
|
||||
use_ne_hifigan: bool = False
|
||||
|
||||
# XTTS GPT Encoder params
|
||||
tokenizer_file: str = ""
|
||||
|
@ -266,6 +269,15 @@ class XttsArgs(Coqpit):
|
|||
diff_layer_drop: int = 0
|
||||
diff_unconditioned_percentage: int = 0
|
||||
|
||||
# HifiGAN Decoder params
|
||||
input_sample_rate: int = 22050
|
||||
output_sample_rate: int = 24000
|
||||
output_hop_length: int = 256
|
||||
ar_mel_length_compression: int = 1024
|
||||
decoder_input_dim: int = 1024
|
||||
d_vector_dim: int = 512
|
||||
cond_d_vector_in_each_upsampling_layer: bool = True
|
||||
|
||||
# constants
|
||||
duration_const: int = 102400
|
||||
|
||||
|
@ -285,7 +297,6 @@ class Xtts(BaseTTS):
|
|||
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config, ap=None, tokenizer=None)
|
||||
self.lazy_load = self.args.lazy_load
|
||||
self.mel_stats_path = None
|
||||
self.config = config
|
||||
self.gpt_checkpoint = self.args.gpt_checkpoint
|
||||
|
@ -295,14 +306,13 @@ class Xtts(BaseTTS):
|
|||
|
||||
self.tokenizer = VoiceBpeTokenizer()
|
||||
self.gpt = None
|
||||
self.diffusion_decoder = None
|
||||
self.init_models()
|
||||
self.register_buffer("mel_stats", torch.ones(80))
|
||||
|
||||
def init_models(self):
|
||||
"""Initialize the models. We do it here since we need to load the tokenizer first."""
|
||||
if self.tokenizer.tokenizer is not None:
|
||||
self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
|
||||
self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens()
|
||||
self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
|
||||
self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
|
||||
|
||||
|
@ -322,40 +332,50 @@ class Xtts(BaseTTS):
|
|||
stop_audio_token=self.args.gpt_stop_audio_token,
|
||||
)
|
||||
|
||||
self.diffusion_decoder = DiffusionTts(
|
||||
model_channels=self.args.diff_model_channels,
|
||||
num_layers=self.args.diff_num_layers,
|
||||
in_channels=self.args.diff_in_channels,
|
||||
out_channels=self.args.diff_out_channels,
|
||||
in_latent_channels=self.args.diff_in_latent_channels,
|
||||
in_tokens=self.args.diff_in_tokens,
|
||||
dropout=self.args.diff_dropout,
|
||||
use_fp16=self.args.diff_use_fp16,
|
||||
num_heads=self.args.diff_num_heads,
|
||||
layer_drop=self.args.diff_layer_drop,
|
||||
unconditioned_percentage=self.args.diff_unconditioned_percentage,
|
||||
)
|
||||
|
||||
self.vocoder = UnivNetGenerator()
|
||||
if self.args.use_hifigan:
|
||||
self.hifigan_decoder = HifiDecoder(
|
||||
input_sample_rate=self.args.input_sample_rate,
|
||||
output_sample_rate=self.args.output_sample_rate,
|
||||
output_hop_length=self.args.output_hop_length,
|
||||
ar_mel_length_compression=self.args.ar_mel_length_compression,
|
||||
decoder_input_dim=self.args.decoder_input_dim,
|
||||
d_vector_dim=self.args.d_vector_dim,
|
||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
|
||||
if self.args.use_ne_hifigan:
|
||||
self.ne_hifigan_decoder = HifiDecoder(
|
||||
input_sample_rate=self.args.input_sample_rate,
|
||||
output_sample_rate=self.args.output_sample_rate,
|
||||
output_hop_length=self.args.output_hop_length,
|
||||
ar_mel_length_compression=self.args.ar_mel_length_compression,
|
||||
decoder_input_dim=self.args.decoder_input_dim,
|
||||
d_vector_dim=self.args.d_vector_dim,
|
||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
|
||||
if not (self.args.use_hifigan or self.args.use_ne_hifigan):
|
||||
self.diffusion_decoder = DiffusionTts(
|
||||
model_channels=self.args.diff_model_channels,
|
||||
num_layers=self.args.diff_num_layers,
|
||||
in_channels=self.args.diff_in_channels,
|
||||
out_channels=self.args.diff_out_channels,
|
||||
in_latent_channels=self.args.diff_in_latent_channels,
|
||||
in_tokens=self.args.diff_in_tokens,
|
||||
dropout=self.args.diff_dropout,
|
||||
use_fp16=self.args.diff_use_fp16,
|
||||
num_heads=self.args.diff_num_heads,
|
||||
layer_drop=self.args.diff_layer_drop,
|
||||
unconditioned_percentage=self.args.diff_unconditioned_percentage,
|
||||
)
|
||||
self.vocoder = UnivNetGenerator()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@contextmanager
|
||||
def lazy_load_model(self, model):
|
||||
"""Context to load a model on demand.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be loaded.
|
||||
"""
|
||||
if self.lazy_load:
|
||||
yield model
|
||||
else:
|
||||
m = model.to(self.device)
|
||||
yield m
|
||||
m = model.cpu()
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
|
||||
"""Compute the conditioning latents for the GPT model from the given audio.
|
||||
|
||||
|
@ -370,6 +390,7 @@ class Xtts(BaseTTS):
|
|||
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
|
||||
return cond_latent.transpose(1, 2)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_diffusion_cond_latents(
|
||||
self,
|
||||
audio_path,
|
||||
|
@ -389,20 +410,33 @@ class Xtts(BaseTTS):
|
|||
)
|
||||
diffusion_conds.append(cond_mel)
|
||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||
with self.lazy_load_model(self.diffusion_decoder) as diffusion:
|
||||
diffusion_latent = diffusion.get_conditioning(diffusion_conds)
|
||||
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
|
||||
return diffusion_latent
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_speaker_embedding(
|
||||
self,
|
||||
audio_path
|
||||
):
|
||||
audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
|
||||
speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
|
||||
audio.to(self.device), l2_norm=True
|
||||
).unsqueeze(-1).to(self.device)
|
||||
return speaker_embedding
|
||||
|
||||
def get_conditioning_latents(
|
||||
self,
|
||||
audio_path,
|
||||
gpt_cond_len=3,
|
||||
):
|
||||
):
|
||||
speaker_embedding = None
|
||||
diffusion_cond_latents = None
|
||||
if self.args.use_hifigan:
|
||||
speaker_embedding = self.get_speaker_embedding(audio_path)
|
||||
else:
|
||||
diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
|
||||
gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T]
|
||||
diffusion_cond_latents = self.get_diffusion_cond_latents(
|
||||
audio_path,
|
||||
)
|
||||
return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
|
||||
return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
|
||||
|
||||
def synthesize(self, text, config, speaker_wav, language, **kwargs):
|
||||
"""Synthesize speech with the given input text.
|
||||
|
@ -447,10 +481,10 @@ class Xtts(BaseTTS):
|
|||
"decoder_sampler": config.decoder_sampler,
|
||||
}
|
||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
return self.inference(text, ref_audio_path, language, **settings)
|
||||
return self.full_inference(text, ref_audio_path, language, **settings)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
@torch.inference_mode()
|
||||
def full_inference(
|
||||
self,
|
||||
text,
|
||||
ref_audio_path,
|
||||
|
@ -469,6 +503,7 @@ class Xtts(BaseTTS):
|
|||
cond_free_k=2,
|
||||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
decoder="hifigan",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -517,6 +552,9 @@ class Xtts(BaseTTS):
|
|||
Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared.
|
||||
Defaults to 1.0.
|
||||
|
||||
decoder: (str) Selects the decoder to use between ("hifigan", "ne_hifigan" and "diffusion")
|
||||
Defaults to hifigan
|
||||
|
||||
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
|
||||
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
|
||||
here: https://huggingface.co/docs/transformers/internal/generation_utils
|
||||
|
@ -525,6 +563,56 @@ class Xtts(BaseTTS):
|
|||
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
||||
Sample rate is 24kHz.
|
||||
"""
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
speaker_embedding
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
|
||||
return self.inference(
|
||||
text,
|
||||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
temperature=temperature,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=do_sample,
|
||||
decoder_iterations=decoder_iterations,
|
||||
cond_free=cond_free,
|
||||
cond_free_k=cond_free_k,
|
||||
diffusion_temperature=diffusion_temperature,
|
||||
decoder_sampler=decoder_sampler,
|
||||
decoder=decoder,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text,
|
||||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
# GPT inference
|
||||
temperature=0.65,
|
||||
length_penalty=1,
|
||||
repetition_penalty=2.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
decoder_iterations=100,
|
||||
cond_free=True,
|
||||
cond_free_k=2,
|
||||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
decoder="hifigan",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
text = f"[{language}]{text.strip().lower()}"
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
|
@ -532,74 +620,160 @@ class Xtts(BaseTTS):
|
|||
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
|
||||
|
||||
diffuser = load_discrete_vocoder_diffuser(
|
||||
desired_diffusion_steps=decoder_iterations,
|
||||
cond_free=cond_free,
|
||||
cond_free_k=cond_free_k,
|
||||
sampler=decoder_sampler,
|
||||
)
|
||||
if not self.args.use_hifigan:
|
||||
diffuser = load_discrete_vocoder_diffuser(
|
||||
desired_diffusion_steps=decoder_iterations,
|
||||
cond_free=cond_free,
|
||||
cond_free_k=cond_free_k,
|
||||
sampler=decoder_sampler,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
self.gpt = self.gpt.to(self.device)
|
||||
with self.lazy_load_model(self.gpt) as gpt:
|
||||
gpt_codes = gpt.generate(
|
||||
cond_latents=gpt_cond_latent,
|
||||
text_inputs=text_tokens,
|
||||
input_tokens=None,
|
||||
do_sample=do_sample,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.gpt_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
output_attentions=False,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
with self.lazy_load_model(self.gpt) as gpt:
|
||||
expected_output_len = torch.tensor(
|
||||
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
|
||||
)
|
||||
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
|
||||
gpt_latents = gpt(
|
||||
text_tokens,
|
||||
text_len,
|
||||
gpt_codes,
|
||||
expected_output_len,
|
||||
cond_latents=gpt_cond_latent,
|
||||
return_attentions=False,
|
||||
return_latent=True,
|
||||
)
|
||||
silence_token = 83
|
||||
ctokens = 0
|
||||
for k in range(gpt_codes.shape[-1]):
|
||||
if gpt_codes[0, k] == silence_token:
|
||||
ctokens += 1
|
||||
else:
|
||||
ctokens = 0
|
||||
if ctokens > 8:
|
||||
gpt_latents = gpt_latents[:, :k]
|
||||
break
|
||||
|
||||
with self.lazy_load_model(self.diffusion_decoder) as diffusion:
|
||||
gpt_codes = self.gpt.generate(
|
||||
cond_latents=gpt_cond_latent,
|
||||
text_inputs=text_tokens,
|
||||
input_tokens=None,
|
||||
do_sample=do_sample,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.gpt_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
output_attentions=False,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
expected_output_len = torch.tensor(
|
||||
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
|
||||
)
|
||||
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
|
||||
gpt_latents = self.gpt(
|
||||
text_tokens,
|
||||
text_len,
|
||||
gpt_codes,
|
||||
expected_output_len,
|
||||
cond_latents=gpt_cond_latent,
|
||||
return_attentions=False,
|
||||
return_latent=True,
|
||||
)
|
||||
silence_token = 83
|
||||
ctokens = 0
|
||||
for k in range(gpt_codes.shape[-1]):
|
||||
if gpt_codes[0, k] == silence_token:
|
||||
ctokens += 1
|
||||
else:
|
||||
ctokens = 0
|
||||
if ctokens > 8:
|
||||
gpt_latents = gpt_latents[:, :k]
|
||||
break
|
||||
|
||||
if decoder == "hifigan":
|
||||
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
elif decoder == "ne_hifigan":
|
||||
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
||||
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
else:
|
||||
assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
|
||||
mel = do_spectrogram_diffusion(
|
||||
diffusion,
|
||||
self.diffusion_decoder,
|
||||
diffuser,
|
||||
gpt_latents,
|
||||
diffusion_conditioning,
|
||||
temperature=diffusion_temperature,
|
||||
)
|
||||
with self.lazy_load_model(self.vocoder) as vocoder:
|
||||
wav = vocoder.inference(mel)
|
||||
wav = self.vocoder.inference(mel)
|
||||
|
||||
return {"wav": wav.cpu().numpy().squeeze()}
|
||||
|
||||
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
|
||||
"""Handle chunk formatting in streaming mode"""
|
||||
wav_chunk = wav_gen[:-overlap_len]
|
||||
if wav_gen_prev is not None:
|
||||
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
|
||||
if wav_overlap is not None:
|
||||
crossfade_wav = wav_chunk[:overlap_len]
|
||||
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
|
||||
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
|
||||
wav_chunk[:overlap_len] += crossfade_wav
|
||||
wav_overlap = wav_gen[-overlap_len:]
|
||||
wav_gen_prev = wav_gen
|
||||
return wav_chunk, wav_gen_prev, wav_overlap
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_stream(
|
||||
self,
|
||||
text,
|
||||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
# Streaming
|
||||
stream_chunk_size=20,
|
||||
overlap_wav_len=1024,
|
||||
# GPT inference
|
||||
temperature=0.65,
|
||||
length_penalty=1,
|
||||
repetition_penalty=2.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
decoder="hifigan",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
|
||||
text = f"[{language}]{text.strip().lower()}"
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
fake_inputs = self.gpt.compute_embeddings(
|
||||
gpt_cond_latent.to(self.device),
|
||||
text_tokens,
|
||||
)
|
||||
gpt_generator = self.gpt.get_generator(
|
||||
fake_inputs=fake_inputs,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
do_sample=do_sample,
|
||||
num_beams=1,
|
||||
num_return_sequences=1,
|
||||
length_penalty=float(length_penalty),
|
||||
repetition_penalty=float(repetition_penalty),
|
||||
output_attentions=False,
|
||||
output_hidden_states=True,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
last_tokens = []
|
||||
all_latents = []
|
||||
wav_gen_prev = None
|
||||
wav_overlap = None
|
||||
is_end = False
|
||||
|
||||
while not is_end:
|
||||
try:
|
||||
x, latent = next(gpt_generator)
|
||||
last_tokens += [x]
|
||||
all_latents += [latent]
|
||||
except StopIteration:
|
||||
is_end = True
|
||||
|
||||
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
|
||||
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
||||
if decoder == "hifigan":
|
||||
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
elif decoder == "ne_hifigan":
|
||||
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
||||
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
else:
|
||||
raise NotImplementedError("Diffusion for streaming inference not implemented.")
|
||||
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
||||
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|
||||
)
|
||||
last_tokens = []
|
||||
yield wav_chunk
|
||||
|
||||
def forward(self):
|
||||
raise NotImplementedError("XTTS Training is not implemented")
|
||||
|
||||
|
@ -616,7 +790,14 @@ class Xtts(BaseTTS):
|
|||
super().eval()
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir=None,
|
||||
checkpoint_path=None,
|
||||
vocab_path=None,
|
||||
eval=True,
|
||||
strict=True,
|
||||
use_deepspeed=False,
|
||||
):
|
||||
"""
|
||||
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
||||
|
@ -626,7 +807,7 @@ class Xtts(BaseTTS):
|
|||
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
|
||||
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
|
||||
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
|
||||
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False.
|
||||
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
|
||||
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
|
||||
|
||||
Returns:
|
||||
|
@ -636,19 +817,34 @@ class Xtts(BaseTTS):
|
|||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||
|
||||
if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")):
|
||||
self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json"))
|
||||
if os.path.exists(vocab_path):
|
||||
self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
|
||||
|
||||
self.init_models()
|
||||
if eval:
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
|
||||
self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict)
|
||||
|
||||
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
|
||||
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
|
||||
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
|
||||
for key in list(checkpoint.keys()):
|
||||
if key.split(".")[0] in ignore_keys:
|
||||
del checkpoint[key]
|
||||
|
||||
# deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not
|
||||
try:
|
||||
self.load_state_dict(checkpoint, strict=strict)
|
||||
except:
|
||||
if eval:
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
|
||||
self.load_state_dict(checkpoint, strict=strict)
|
||||
|
||||
if eval:
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
|
||||
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
|
||||
if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
|
||||
if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
|
||||
if hasattr(self, "vocoder"): self.vocoder.eval()
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
|
||||
self.gpt.eval()
|
||||
self.diffusion_decoder.eval()
|
||||
self.vocoder.eval()
|
||||
|
||||
def train_step(self):
|
||||
raise NotImplementedError("XTTS Training is not implemented")
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
import librosa
|
||||
|
@ -427,16 +428,24 @@ def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False,
|
|||
return x
|
||||
|
||||
|
||||
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None:
|
||||
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out = None, **kwargs) -> None:
|
||||
"""Save float waveform to a file using Scipy.
|
||||
|
||||
Args:
|
||||
wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
|
||||
path (str): Path to a output file.
|
||||
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
|
||||
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
|
||||
"""
|
||||
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
|
||||
scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16))
|
||||
|
||||
wav_norm = wav_norm.astype(np.int16)
|
||||
if pipe_out:
|
||||
wav_buffer = BytesIO()
|
||||
scipy.io.wavfile.write(wav_buffer, sample_rate, wav_norm)
|
||||
wav_buffer.seek(0)
|
||||
pipe_out.buffer.write(wav_buffer.read())
|
||||
scipy.io.wavfile.write(path, sample_rate, wav_norm)
|
||||
|
||||
|
||||
def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from io import BytesIO
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import librosa
|
||||
|
@ -693,20 +694,27 @@ class AudioProcessor(object):
|
|||
x = self.rms_volume_norm(x, self.db_level)
|
||||
return x
|
||||
|
||||
def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:
|
||||
def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out = None) -> None:
|
||||
"""Save a waveform to a file using Scipy.
|
||||
|
||||
Args:
|
||||
wav (np.ndarray): Waveform to save.
|
||||
path (str): Path to a output file.
|
||||
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
|
||||
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
|
||||
"""
|
||||
if self.do_rms_norm:
|
||||
wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
|
||||
else:
|
||||
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
|
||||
|
||||
scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16))
|
||||
wav_norm = wav_norm.astype(np.int16)
|
||||
if pipe_out:
|
||||
wav_buffer = BytesIO()
|
||||
scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm)
|
||||
wav_buffer.seek(0)
|
||||
pipe_out.buffer.write(wav_buffer.read())
|
||||
scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm)
|
||||
|
||||
def get_duration(self, filename: str) -> float:
|
||||
"""Get the duration of a wav file using Librosa.
|
||||
|
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
from shutil import copyfile, rmtree
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import fsspec
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -293,8 +294,9 @@ class ModelManager(object):
|
|||
# get model from models.json
|
||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||
model_item["model_type"] = model_type
|
||||
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
|
||||
model_item = self.set_model_url(model_item)
|
||||
return model_item, model_full_name, model
|
||||
return model_item, model_full_name, model, md5hash
|
||||
|
||||
def ask_tos(self, model_full_path):
|
||||
"""Ask the user to agree to the terms of service"""
|
||||
|
@ -320,6 +322,44 @@ class ModelManager(object):
|
|||
return False
|
||||
return True
|
||||
|
||||
def create_dir_and_download_model(self, model_name, model_item, output_path):
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
# handle TOS
|
||||
if not self.tos_agreed(model_item, output_path):
|
||||
if not self.ask_tos(output_path):
|
||||
os.rmdir(output_path)
|
||||
raise Exception(" [!] You must agree to the terms of service to use this model.")
|
||||
print(f" > Downloading model to {output_path}")
|
||||
try:
|
||||
if "fairseq" in model_name:
|
||||
self.download_fairseq_model(model_name, output_path)
|
||||
elif "github_rls_url" in model_item:
|
||||
self._download_github_model(model_item, output_path)
|
||||
elif "hf_url" in model_item:
|
||||
self._download_hf_model(model_item, output_path)
|
||||
|
||||
except requests.RequestException as e:
|
||||
print(f" > Failed to download the model file to {output_path}")
|
||||
rmtree(output_path)
|
||||
raise e
|
||||
self.print_model_license(model_item=model_item)
|
||||
|
||||
def check_if_configs_are_equal(self, model_name, model_item, output_path):
|
||||
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
|
||||
config_local = json.load(f)
|
||||
remote_url = None
|
||||
for url in model_item["hf_url"]:
|
||||
if "config.json" in url:
|
||||
remote_url = url
|
||||
break
|
||||
|
||||
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
|
||||
config_remote = json.load(f)
|
||||
|
||||
if not config_local == config_remote:
|
||||
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
|
||||
def download_model(self, model_name):
|
||||
"""Download model files given the full model name.
|
||||
Model name is in the format
|
||||
|
@ -334,37 +374,39 @@ class ModelManager(object):
|
|||
Args:
|
||||
model_name (str): model name as explained above.
|
||||
"""
|
||||
model_item, model_full_name, model = self._set_model_item(model_name)
|
||||
model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
|
||||
# set the model specific output path
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if os.path.exists(output_path):
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
if md5sum is not None:
|
||||
md5sum_file = os.path.join(output_path, "hash.md5")
|
||||
if os.path.isfile(md5sum_file):
|
||||
with open(md5sum_file, mode="r") as f:
|
||||
if not f.read() == md5sum:
|
||||
print(f" > {model_name} has been updated, clearing model cache...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
else:
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
else:
|
||||
print(f" > {model_name} has been updated, clearing model cache...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
# if the configs are different, redownload it
|
||||
# ToDo: we need a better way to handle it
|
||||
if "xtts_v1" in model_name:
|
||||
try:
|
||||
self.check_if_configs_are_equal(model_name, model_item, output_path)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
else:
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
# handle TOS
|
||||
if not self.tos_agreed(model_item, output_path):
|
||||
if not self.ask_tos(output_path):
|
||||
os.rmdir(output_path)
|
||||
raise Exception(" [!] You must agree to the terms of service to use this model.")
|
||||
print(f" > Downloading model to {output_path}")
|
||||
try:
|
||||
if "fairseq" in model_name:
|
||||
self.download_fairseq_model(model_name, output_path)
|
||||
elif "github_rls_url" in model_item:
|
||||
self._download_github_model(model_item, output_path)
|
||||
elif "hf_url" in model_item:
|
||||
self._download_hf_model(model_item, output_path)
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
|
||||
except requests.RequestException as e:
|
||||
print(f" > Failed to download the model file to {output_path}")
|
||||
rmtree(output_path)
|
||||
raise e
|
||||
self.print_model_license(model_item=model_item)
|
||||
# find downloaded files
|
||||
output_model_path = output_path
|
||||
output_config_path = None
|
||||
if (
|
||||
model not in ["tortoise-v2", "bark", "xtts_v1"] and "fairseq" not in model_name
|
||||
model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name
|
||||
): # TODO:This is stupid but don't care for now.
|
||||
output_model_path, output_config_path = self._find_files(output_path)
|
||||
# update paths in the config.json
|
||||
|
@ -498,13 +540,16 @@ class ModelManager(object):
|
|||
print(f" > Error: Bad zip file - {file_url}")
|
||||
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
|
||||
# move the files to the outer path
|
||||
for file_path in z.namelist()[1:]:
|
||||
for file_path in z.namelist():
|
||||
src_path = os.path.join(output_folder, file_path)
|
||||
dst_path = os.path.join(output_folder, os.path.basename(file_path))
|
||||
if src_path != dst_path:
|
||||
copyfile(src_path, dst_path)
|
||||
# remove the extracted folder
|
||||
rmtree(os.path.join(output_folder, z.namelist()[0]))
|
||||
if os.path.isfile(src_path):
|
||||
dst_path = os.path.join(output_folder, os.path.basename(file_path))
|
||||
if src_path != dst_path:
|
||||
copyfile(src_path, dst_path)
|
||||
# remove redundant (hidden or not) folders
|
||||
for file_path in z.namelist():
|
||||
if os.path.isdir(os.path.join(output_folder, file_path)):
|
||||
rmtree(os.path.join(output_folder, file_path))
|
||||
|
||||
@staticmethod
|
||||
def _download_tar_file(file_url, output_folder, progress_bar):
|
||||
|
|
|
@ -235,19 +235,20 @@ class Synthesizer(nn.Module):
|
|||
"""
|
||||
return self.seg.segment(text)
|
||||
|
||||
def save_wav(self, wav: List[int], path: str) -> None:
|
||||
def save_wav(self, wav: List[int], path: str, pipe_out = None) -> None:
|
||||
"""Save the waveform as a file.
|
||||
|
||||
Args:
|
||||
wav (List[int]): waveform as a list of values.
|
||||
path (str): output path to save the waveform.
|
||||
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
|
||||
"""
|
||||
# if tensor convert to numpy
|
||||
if torch.is_tensor(wav):
|
||||
wav = wav.cpu().numpy()
|
||||
if isinstance(wav, list):
|
||||
wav = np.array(wav)
|
||||
save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate)
|
||||
save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out)
|
||||
|
||||
def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
|
||||
output_wav = self.vc_model.voice_conversion(source_wav, target_wav)
|
||||
|
@ -299,11 +300,7 @@ class Synthesizer(nn.Module):
|
|||
speaker_embedding = None
|
||||
speaker_id = None
|
||||
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
|
||||
# handle Neon models with single speaker.
|
||||
if len(self.tts_model.speaker_manager.name_to_id) == 1:
|
||||
speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0]
|
||||
|
||||
elif speaker_name and isinstance(speaker_name, str):
|
||||
if speaker_name and isinstance(speaker_name, str):
|
||||
if self.tts_config.use_d_vector_file:
|
||||
# get the average speaker embedding from the saved d_vectors.
|
||||
speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
|
||||
|
@ -313,7 +310,9 @@ class Synthesizer(nn.Module):
|
|||
else:
|
||||
# get speaker idx from the speaker name
|
||||
speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name]
|
||||
|
||||
# handle Neon models with single speaker.
|
||||
elif len(self.tts_model.speaker_manager.name_to_id) == 1:
|
||||
speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0]
|
||||
elif not speaker_name and not speaker_wav:
|
||||
raise ValueError(
|
||||
" [!] Looks like you are using a multi-speaker model. "
|
||||
|
|
|
@ -116,12 +116,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
|||
return acts
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def shift_1d(x):
|
||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||
return x
|
||||
|
|
|
@ -17,19 +17,20 @@ Let's assume you created the audio clips and their transcription. You can collec
|
|||
...
|
||||
```
|
||||
|
||||
You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each line must be delimitered by a special character separating the audio file name from the transcription. And make sure that the delimiter is not used in the transcription text.
|
||||
You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each column must be delimitered by a special character separating the audio file name, the transcription and the normalized transcription. And make sure that the delimiter is not used in the transcription text.
|
||||
|
||||
We recommend the following format delimited by `|`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc.
|
||||
|
||||
```
|
||||
# metadata.txt
|
||||
|
||||
audio1|This is my sentence.
|
||||
audio2|This is maybe my sentence.
|
||||
audio3|This is certainly my sentence.
|
||||
audio4|Let this be your sentence.
|
||||
audio1|This is my sentence.|This is my sentence.
|
||||
audio2|1469 and 1470|fourteen sixty-nine and fourteen seventy
|
||||
audio3|It'll be $16 sir.|It'll be sixteen dollars sir.
|
||||
...
|
||||
```
|
||||
*If you don't have normalized transcriptions, you can use the same transcription for both columns. If it's your case, we recommend to use normalization later in the pipeline, either in the text cleaner or in the phonemizer.*
|
||||
|
||||
|
||||
In the end, we have the following folder structure
|
||||
```
|
||||
|
|
|
@ -41,7 +41,7 @@
|
|||
6. Optionally, define `MyModelArgs`.
|
||||
|
||||
`MyModelArgs` is a 👨✈️Coqpit class that sets all the class arguments of the `MyModel`. `MyModelArgs` must have
|
||||
all the fields neccessary to instantiate the `MyModel`. However, for training, you need to pass `MyModelConfig` to
|
||||
all the fields necessary to instantiate the `MyModel`. However, for training, you need to pass `MyModelConfig` to
|
||||
the model.
|
||||
|
||||
7. Test `MyModel`.
|
||||
|
|
|
@ -114,18 +114,24 @@ tts-server --model_name "<type>/<language>/<dataset>/<model_name>" \
|
|||
You can run a multi-speaker and multi-lingual model in Python as
|
||||
|
||||
```python
|
||||
import torch
|
||||
from TTS.api import TTS
|
||||
|
||||
# List available 🐸TTS models and choose the first one
|
||||
model_name = TTS().list_models()[0]
|
||||
# Get device
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# List available 🐸TTS models
|
||||
print(TTS().list_models())
|
||||
|
||||
# Init TTS
|
||||
tts = TTS(model_name)
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1").to(device)
|
||||
|
||||
# Run TTS
|
||||
# ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
|
||||
# Text to speech with a numpy output
|
||||
wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
|
||||
# ❗ Since this model is multi-lingual voice cloning model, we must set the target speaker_wav and language
|
||||
# Text to speech list of amplitude values as output
|
||||
wav = tts.tts(text="Hello world!", speaker_wav="my/cloning/audio.wav", language="en")
|
||||
# Text to speech to a file
|
||||
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
|
||||
tts.tts_to_file(text="Hello world!", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
|
||||
```
|
||||
|
||||
#### Here is an example for a single speaker model.
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
# Trainer API
|
||||
|
||||
We made the trainer a seprate project on https://github.com/coqui-ai/Trainer
|
||||
We made the trainer a separate project on https://github.com/coqui-ai/Trainer
|
||||
|
|
|
@ -12,7 +12,7 @@ Currently we provide the following pre-configured architectures:
|
|||
|
||||
- **FastPitch:**
|
||||
|
||||
It uses the same FastSpeech architecture that is conditioned on fundemental frequency (f0) contours with the
|
||||
It uses the same FastSpeech architecture that is conditioned on fundamental frequency (f0) contours with the
|
||||
promise of more expressive speech.
|
||||
|
||||
- **SpeedySpeech:**
|
||||
|
|
|
@ -28,7 +28,8 @@ This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml)
|
|||
Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
|
||||
You can also mail us at info@coqui.ai.
|
||||
|
||||
Using 🐸TTS API:
|
||||
### Inference
|
||||
#### 🐸TTS API
|
||||
|
||||
```python
|
||||
from TTS.api import TTS
|
||||
|
@ -39,16 +40,9 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
|
|||
file_path="output.wav",
|
||||
speaker_wav="/path/to/target/speaker.wav",
|
||||
language="en")
|
||||
|
||||
# generate speech by cloning a voice using custom settings
|
||||
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
file_path="output.wav",
|
||||
speaker_wav="/path/to/target/speaker.wav",
|
||||
language="en",
|
||||
decoder_iterations=30)
|
||||
```
|
||||
|
||||
Using 🐸TTS Command line:
|
||||
#### 🐸TTS Command line
|
||||
|
||||
```console
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
|
||||
|
@ -58,25 +52,85 @@ Using 🐸TTS Command line:
|
|||
--use_cuda true
|
||||
```
|
||||
|
||||
Using model directly:
|
||||
#### model directly
|
||||
|
||||
If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
|
||||
|
||||
```console
|
||||
pip install deepspeed==0.8.3
|
||||
```
|
||||
|
||||
```python
|
||||
import os
|
||||
import torch
|
||||
import torchaudio
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
print("Loading model...")
|
||||
config = XttsConfig()
|
||||
config.load_json("/path/to/xtts/config.json")
|
||||
model = Xtts.init_from_config(config)
|
||||
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
|
||||
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
|
||||
model.cuda()
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
||||
|
||||
print("Inference...")
|
||||
out = model.inference(
|
||||
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
|
||||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
temperature=0.7, # Add custom parameters here
|
||||
)
|
||||
torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
||||
```
|
||||
|
||||
|
||||
#### streaming inference
|
||||
|
||||
Here the goal is to stream the audio as it is being generated. This is useful for real-time applications.
|
||||
Streaming inference is typically slower than regular inference, but it allows to get a first chunk of audio faster.
|
||||
|
||||
|
||||
```python
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torchaudio
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
print("Loading model...")
|
||||
config = XttsConfig()
|
||||
config.load_json("/path/to/xtts/config.json")
|
||||
model = Xtts.init_from_config(config)
|
||||
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
|
||||
model.cuda()
|
||||
|
||||
outputs = model.synthesize(
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
||||
|
||||
print("Inference...")
|
||||
t0 = time.time()
|
||||
chunks = model.inference_stream(
|
||||
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
|
||||
config,
|
||||
speaker_wav="/data/TTS-public/_refclips/3.wav",
|
||||
gpt_cond_len=3,
|
||||
language="en",
|
||||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding
|
||||
)
|
||||
|
||||
wav_chuncks = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i == 0:
|
||||
print(f"Time to first chunck: {time.time() - t0}")
|
||||
print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
|
||||
wav_chuncks.append(chunk)
|
||||
wav = torch.cat(wav_chuncks, dim=0)
|
||||
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -13,15 +13,15 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import torch\n",
|
||||
"import importlib\n",
|
||||
"import numpy as np\n",
|
||||
"from tqdm import tqdm as tqdm\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import soundfile as sf\n",
|
||||
"import pickle\n",
|
||||
"from TTS.tts.datasets.dataset import TTSDataset\n",
|
||||
"from TTS.tts.layers.losses import L1LossMasked\n",
|
||||
"from TTS.utils.audio import AudioProcessor\n",
|
||||
|
@ -33,8 +33,8 @@
|
|||
"\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
|
||||
"# Configure CUDA visibility\n",
|
||||
"os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -43,6 +43,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Function to create directories and file names\n",
|
||||
"def set_filename(wav_path, out_path):\n",
|
||||
" wav_file = os.path.basename(wav_path)\n",
|
||||
" file_name = wav_file.split('.')[0]\n",
|
||||
|
@ -61,6 +62,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Paths and configurations\n",
|
||||
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
|
||||
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
|
||||
"DATASET = \"ljspeech\"\n",
|
||||
|
@ -73,12 +75,15 @@
|
|||
"QUANTIZE_BIT = None\n",
|
||||
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
||||
"\n",
|
||||
"# Check CUDA availability\n",
|
||||
"use_cuda = torch.cuda.is_available()\n",
|
||||
"print(\" > CUDA enabled: \", use_cuda)\n",
|
||||
"\n",
|
||||
"# Load the configuration\n",
|
||||
"C = load_config(CONFIG_PATH)\n",
|
||||
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
|
||||
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
|
||||
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
|
||||
"print(C['r'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -87,14 +92,13 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(C['r'])\n",
|
||||
"# if the vocabulary was passed, replace the default\n",
|
||||
"# If the vocabulary was passed, replace the default\n",
|
||||
"if 'characters' in C and C['characters']:\n",
|
||||
" symbols, phonemes = make_symbols(**C.characters)\n",
|
||||
"\n",
|
||||
"# load the model\n",
|
||||
"# Load the model\n",
|
||||
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
|
||||
"# TODO: multiple speaker\n",
|
||||
"# TODO: multiple speakers\n",
|
||||
"model = setup_model(C)\n",
|
||||
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
||||
]
|
||||
|
@ -105,11 +109,12 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the preprocessor based on the dataset\n",
|
||||
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
|
||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
||||
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
|
||||
"dataset = TTSDataset(\n",
|
||||
" checkpoint[\"config\"][\"r\"],\n",
|
||||
" C,\n",
|
||||
" C.text_cleaner,\n",
|
||||
" False,\n",
|
||||
" ap,\n",
|
||||
|
@ -124,6 +129,24 @@
|
|||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Initialize lists for storing results\n",
|
||||
"file_idxs = []\n",
|
||||
"metadata = []\n",
|
||||
"losses = []\n",
|
||||
"postnet_losses = []\n",
|
||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||
"\n",
|
||||
"# Create log file\n",
|
||||
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
||||
"log_file = open(log_file_path, \"w\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -137,83 +160,85 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"file_idxs = []\n",
|
||||
"metadata = []\n",
|
||||
"losses = []\n",
|
||||
"postnet_losses = []\n",
|
||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||
"# Start processing with a progress bar\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for data in tqdm(loader):\n",
|
||||
" # setup input data\n",
|
||||
" text_input = data[0]\n",
|
||||
" text_lengths = data[1]\n",
|
||||
" linear_input = data[3]\n",
|
||||
" mel_input = data[4]\n",
|
||||
" mel_lengths = data[5]\n",
|
||||
" stop_targets = data[6]\n",
|
||||
" item_idx = data[7]\n",
|
||||
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
||||
" try:\n",
|
||||
" # setup input data\n",
|
||||
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
|
||||
"\n",
|
||||
" # dispatch data to GPU\n",
|
||||
" if use_cuda:\n",
|
||||
" text_input = text_input.cuda()\n",
|
||||
" text_lengths = text_lengths.cuda()\n",
|
||||
" mel_input = mel_input.cuda()\n",
|
||||
" mel_lengths = mel_lengths.cuda()\n",
|
||||
" # dispatch data to GPU\n",
|
||||
" if use_cuda:\n",
|
||||
" text_input = text_input.cuda()\n",
|
||||
" text_lengths = text_lengths.cuda()\n",
|
||||
" mel_input = mel_input.cuda()\n",
|
||||
" mel_lengths = mel_lengths.cuda()\n",
|
||||
"\n",
|
||||
" mask = sequence_mask(text_lengths)\n",
|
||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
||||
" \n",
|
||||
" # compute loss\n",
|
||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
" postnet_losses.append(loss_postnet.item())\n",
|
||||
" mask = sequence_mask(text_lengths)\n",
|
||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
||||
"\n",
|
||||
" # compute mel specs from linear spec if model is Tacotron\n",
|
||||
" if C.model == \"Tacotron\":\n",
|
||||
" mel_specs = []\n",
|
||||
" postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
|
||||
" for b in range(postnet_outputs.shape[0]):\n",
|
||||
" postnet_output = postnet_outputs[b]\n",
|
||||
" mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
|
||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||
" elif C.model == \"Tacotron2\":\n",
|
||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||
" alignments = alignments.detach().cpu().numpy()\n",
|
||||
" # compute loss\n",
|
||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
" postnet_losses.append(loss_postnet.item())\n",
|
||||
"\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" for idx in range(text_input.shape[0]):\n",
|
||||
" wav_file_path = item_idx[idx]\n",
|
||||
" wav = ap.load_wav(wav_file_path)\n",
|
||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||
" file_idxs.append(file_name)\n",
|
||||
" # compute mel specs from linear spec if the model is Tacotron\n",
|
||||
" if C.model == \"Tacotron\":\n",
|
||||
" mel_specs = []\n",
|
||||
" postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
|
||||
" for b in range(postnet_outputs.shape[0]):\n",
|
||||
" postnet_output = postnet_outputs[b]\n",
|
||||
" mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
|
||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||
" elif C.model == \"Tacotron2\":\n",
|
||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||
" alignments = alignments.detach().cpu().numpy()\n",
|
||||
"\n",
|
||||
" # quantize and save wav\n",
|
||||
" if QUANTIZED_WAV:\n",
|
||||
" wavq = ap.quantize(wav)\n",
|
||||
" np.save(wavq_path, wavq)\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" for idx in range(text_input.shape[0]):\n",
|
||||
" wav_file_path = item_idx[idx]\n",
|
||||
" wav = ap.load_wav(wav_file_path)\n",
|
||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||
" file_idxs.append(file_name)\n",
|
||||
"\n",
|
||||
" # save TTS mel\n",
|
||||
" mel = postnet_outputs[idx]\n",
|
||||
" mel_length = mel_lengths[idx]\n",
|
||||
" mel = mel[:mel_length, :].T\n",
|
||||
" np.save(mel_path, mel)\n",
|
||||
" # quantize and save wav\n",
|
||||
" if QUANTIZED_WAV:\n",
|
||||
" wavq = ap.quantize(wav)\n",
|
||||
" np.save(wavq_path, wavq)\n",
|
||||
"\n",
|
||||
" metadata.append([wav_file_path, mel_path])\n",
|
||||
" # save TTS mel\n",
|
||||
" mel = postnet_outputs[idx]\n",
|
||||
" mel_length = mel_lengths[idx]\n",
|
||||
" mel = mel[:mel_length, :].T\n",
|
||||
" np.save(mel_path, mel)\n",
|
||||
"\n",
|
||||
" # for wavernn\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n",
|
||||
" \n",
|
||||
" # for pwgan\n",
|
||||
" with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||
" for data in metadata:\n",
|
||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
||||
" metadata.append([wav_file_path, mel_path])\n",
|
||||
"\n",
|
||||
" print(np.mean(losses))\n",
|
||||
" print(np.mean(postnet_losses))"
|
||||
" except Exception as e:\n",
|
||||
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
||||
"\n",
|
||||
" # Calculate and log mean losses\n",
|
||||
" mean_loss = np.mean(losses)\n",
|
||||
" mean_postnet_loss = np.mean(postnet_losses)\n",
|
||||
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
||||
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
||||
"\n",
|
||||
"# Close the log file\n",
|
||||
"log_file.close()\n",
|
||||
"\n",
|
||||
"# For wavernn\n",
|
||||
"if not DRY_RUN:\n",
|
||||
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
||||
"\n",
|
||||
"# For pwgan\n",
|
||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||
" for data in metadata:\n",
|
||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
||||
"\n",
|
||||
"# Print mean losses\n",
|
||||
"print(f\"Mean Loss: {mean_loss}\")\n",
|
||||
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -100,7 +100,7 @@
|
|||
" wav_file = item[\"audio_file\"].strip()\n",
|
||||
" wav_files.append(wav_file)\n",
|
||||
" if not os.path.exists(wav_file):\n",
|
||||
" print(waf_path)"
|
||||
" print(wav_file)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
[build-system]
|
||||
requires = ["setuptools", "wheel", "cython==0.29.30", "numpy==1.22.0", "packaging"]
|
||||
requires = [
|
||||
"setuptools",
|
||||
"wheel",
|
||||
"cython~=0.29.30",
|
||||
"numpy>=1.22.0",
|
||||
"packaging",
|
||||
]
|
||||
|
||||
[flake8]
|
||||
max-line-length=120
|
||||
|
@ -7,25 +13,6 @@ max-line-length=120
|
|||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py39']
|
||||
exclude = '''
|
||||
|
||||
(
|
||||
/(
|
||||
\.eggs # exclude a few common directories in the
|
||||
| \.git # root of the project
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| _build
|
||||
| buck-out
|
||||
| build
|
||||
| dist
|
||||
)/
|
||||
| foo.py # also separately exclude a file named foo.py in
|
||||
# the root of the project
|
||||
)
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
line_length = 120
|
||||
|
|
|
@ -2,3 +2,4 @@
|
|||
# japanese g2p deps
|
||||
mecab-python3==1.0.6
|
||||
unidic-lite==1.0.8
|
||||
cutlet
|
||||
|
|
|
@ -5,26 +5,27 @@ cython==0.29.30
|
|||
scipy>=1.11.2
|
||||
torch>=1.7
|
||||
torchaudio
|
||||
soundfile
|
||||
librosa==0.10.0.*
|
||||
soundfile==0.12.*
|
||||
librosa==0.10.*
|
||||
scikit-learn==1.3.0
|
||||
numba==0.55.1;python_version<"3.9"
|
||||
numba==0.57.0;python_version>="3.9"
|
||||
inflect==5.6.0
|
||||
tqdm
|
||||
anyascii
|
||||
pyyaml
|
||||
inflect==5.6.*
|
||||
tqdm==4.64.*
|
||||
anyascii==0.3.*
|
||||
pyyaml==6.*
|
||||
fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
|
||||
aiohttp
|
||||
packaging
|
||||
aiohttp==3.8.*
|
||||
packaging==23.1
|
||||
# deps for examples
|
||||
flask
|
||||
flask==2.*
|
||||
# deps for inference
|
||||
pysbd
|
||||
pysbd==0.3.4
|
||||
# deps for notebooks
|
||||
umap-learn==0.5.1
|
||||
pandas
|
||||
umap-learn==0.5.*
|
||||
pandas>=1.4,<2.0
|
||||
# deps for training
|
||||
matplotlib
|
||||
matplotlib==3.7.*
|
||||
# coqui stack
|
||||
trainer
|
||||
# config management
|
||||
|
@ -39,14 +40,14 @@ jamo
|
|||
nltk
|
||||
g2pkk>=0.1.1
|
||||
# deps for bangla
|
||||
bangla==0.0.2
|
||||
bangla
|
||||
bnnumerizer
|
||||
bnunicodenormalizer==0.1.1
|
||||
bnunicodenormalizer
|
||||
#deps for tortoise
|
||||
k_diffusion
|
||||
einops
|
||||
transformers
|
||||
einops==0.6.*
|
||||
transformers==4.33.*
|
||||
#deps for bark
|
||||
encodec
|
||||
encodec==0.1.*
|
||||
# deps for XTTS
|
||||
unidecode
|
||||
unidecode==1.3.*
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def replace_between_markers(content, marker: str, replacement: str) -> str:
|
||||
start_marker = f"<!-- begin-{marker} -->\n\n"
|
||||
end_marker = f"\n\n<!-- end-{marker} -->\n"
|
||||
start_index = content.index(start_marker) + len(start_marker)
|
||||
end_index = content.index(end_marker)
|
||||
content = content[:start_index] + replacement + content[end_index:]
|
||||
return content
|
||||
|
||||
|
||||
def sync_readme():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--check", action="store_true", default=False)
|
||||
args = ap.parse_args()
|
||||
readme_path = Path(__file__).parent.parent / "README.md"
|
||||
orig_content = readme_path.read_text()
|
||||
from TTS.bin.synthesize import description
|
||||
|
||||
new_content = replace_between_markers(orig_content, "tts-readme", description.strip())
|
||||
if args.check:
|
||||
if orig_content != new_content:
|
||||
print("README.md is out of sync; please edit TTS/bin/TTS_README.md and run scripts/sync_readme.py")
|
||||
exit(42)
|
||||
readme_path.write_text(new_content)
|
||||
print("Updated README.md")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sync_readme()
|
|
@ -13,3 +13,16 @@ def test_synthesize():
|
|||
'--text "This is it" '
|
||||
f'--out_path "{output_path}"'
|
||||
)
|
||||
|
||||
# 🐸 Coqui studio model with speed arg.
|
||||
run_cli(
|
||||
'tts --model_name "coqui_studio/en/Torcull Diarmuid/coqui_studio" '
|
||||
'--text "This is it but slow" --speed 0.1'
|
||||
f'--out_path "{output_path}"'
|
||||
)
|
||||
|
||||
# test pipe_out command
|
||||
run_cli(
|
||||
'tts --text "test." --pipe_out '
|
||||
f'--out_path "{output_path}" | aplay'
|
||||
)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_readme_up_to_date():
|
||||
root = Path(__file__).parent.parent.parent
|
||||
sync_readme = root / "scripts" / "sync_readme.py"
|
||||
subprocess.check_call([sys.executable, str(sync_readme), "--check"], cwd=root)
|
|
@ -3,13 +3,20 @@ import glob
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from tests import get_tests_data_path, get_tests_output_path, run_cli
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
MODELS_WITH_SEP_TESTS = ["bark", "xtts"]
|
||||
MODELS_WITH_SEP_TESTS = [
|
||||
"tts_models/multilingual/multi-dataset/bark",
|
||||
"tts_models/en/multi-dataset/tortoise-v2",
|
||||
"tts_models/multilingual/multi-dataset/xtts_v1",
|
||||
"tts_models/multilingual/multi-dataset/xtts_v1.1",
|
||||
]
|
||||
|
||||
|
||||
def run_models(offset=0, step=1):
|
||||
|
@ -17,7 +24,8 @@ def run_models(offset=0, step=1):
|
|||
print(" > Run synthesizer with all the models.")
|
||||
output_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
|
||||
model_names = [name for name in manager.list_models() if name in MODELS_WITH_SEP_TESTS]
|
||||
model_names = [name for name in manager.list_models() if name not in MODELS_WITH_SEP_TESTS]
|
||||
print("Model names:", model_names)
|
||||
for model_name in model_names[offset::step]:
|
||||
print(f"\n > Run - {model_name}")
|
||||
model_path, _, _ = manager.download_model(model_name)
|
||||
|
@ -67,23 +75,85 @@ def run_models(offset=0, step=1):
|
|||
|
||||
|
||||
def test_xtts():
|
||||
"""XTTS is too big to run on github actions. We need to test it locally"""
|
||||
output_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
|
||||
run_cli(
|
||||
"yes | "
|
||||
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
|
||||
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
|
||||
use_gpu = torch.cuda.is_available()
|
||||
if use_gpu:
|
||||
run_cli(
|
||||
"yes | "
|
||||
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
|
||||
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
|
||||
)
|
||||
else:
|
||||
run_cli(
|
||||
"yes | "
|
||||
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
|
||||
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
|
||||
)
|
||||
|
||||
|
||||
def test_xtts_streaming():
|
||||
"""Testing the new inference_stream method"""
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
|
||||
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
|
||||
config = XttsConfig()
|
||||
config.load_json(os.path.join(model_path, "config.json"))
|
||||
model = Xtts.init_from_config(config)
|
||||
model.load_checkpoint(config, checkpoint_dir=model_path)
|
||||
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
|
||||
|
||||
print("Inference...")
|
||||
chunks = model.inference_stream(
|
||||
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
|
||||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding
|
||||
)
|
||||
wav_chuncks = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i == 0:
|
||||
assert chunk.shape[-1] > 5000
|
||||
wav_chuncks.append(chunk)
|
||||
assert len(wav_chuncks) > 1
|
||||
|
||||
|
||||
def test_tortoise():
|
||||
output_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
use_gpu = torch.cuda.is_available()
|
||||
if use_gpu:
|
||||
run_cli(
|
||||
f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
|
||||
)
|
||||
else:
|
||||
run_cli(
|
||||
f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
|
||||
)
|
||||
|
||||
|
||||
def test_bark():
|
||||
"""Bark is too big to run on github actions. We need to test it locally"""
|
||||
output_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
run_cli(
|
||||
f" tts --model_name tts_models/multilingual/multi-dataset/bark "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
|
||||
)
|
||||
use_gpu = torch.cuda.is_available()
|
||||
if use_gpu:
|
||||
run_cli(
|
||||
f" tts --model_name tts_models/multilingual/multi-dataset/bark "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
|
||||
)
|
||||
else:
|
||||
run_cli(
|
||||
f" tts --model_name tts_models/multilingual/multi-dataset/bark "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
|
||||
)
|
||||
|
||||
|
||||
def test_voice_conversion():
|
||||
|
|
Loading…
Reference in New Issue