mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'coqui-ai:dev' into progress_bar
This commit is contained in:
commit 8a2c4b778a
@@ -10,7 +10,7 @@ jobs:
 build-sdist:
 runs-on: ubuntu-20.04
 steps:
-- uses: actions/checkout@v2
+- uses: actions/checkout@v3
 - name: Verify tag matches version
 run: |
 set -ex
@@ -38,7 +38,7 @@ jobs:
 matrix:
 python-version: ["3.9", "3.10", "3.11"]
 steps:
-- uses: actions/checkout@v2
+- uses: actions/checkout@v3
 - uses: actions/setup-python@v2
 with:
 python-version: ${{ matrix.python-version }}
@@ -42,6 +42,5 @@ jobs:
 run: |
 python3 -m pip install .[all]
 python3 setup.py egg_info
-# - name: Lint check
-# run: |
-# make lint
+- name: Style check
+run: make style
@@ -169,3 +169,4 @@ wandb
 depot/*
 coqui_recipes/*
 local_scripts/*
+coqui_demos/*

README.md | 15
@@ -1,5 +1,8 @@
 
 ## 🐸Coqui.ai News
+- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board.
+- 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech).
+- 📣 ⓍTTS can now stream with <200ms latency.
 - 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://tts.readthedocs.io/en/dev/models/xtts.html)
 - 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html)
 - 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
@@ -25,7 +28,7 @@
 📚 Utilities for dataset analysis and curation.
 ______________________________________________________________________
 
 [](https://discord.gg/5eXr5seRrv)
 [](https://opensource.org/licenses/MPL-2.0)
 [](https://badge.fury.io/py/TTS)
 [](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md)
@@ -202,7 +205,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 print(TTS().list_models())
 
 # Init TTS
-tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1").to(device)
+tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
 
 # Run TTS
 # ❗ Since this model is multi-lingual voice cloning model, we must set the target speaker_wav and language
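Note: the hunk above only switches the model name; for context, a minimal sketch of the full updated call, with a placeholder reference clip path that is not part of this diff:

```python
import torch
from TTS.api import TTS

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the XTTS v2 model referenced in the hunk above.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

# XTTS is a multilingual voice-cloning model, so both a reference clip
# and a target language must be given.
tts.tts_to_file(
    text="Hello world!",
    speaker_wav="/path/to/reference.wav",  # placeholder path
    language="en",
    file_path="output.wav",
)
```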
@@ -264,19 +267,13 @@ models = TTS(cs_api_model="XTTS").list_models()
 # Init TTS with the target studio speaker
 tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False)
 # Run TTS
-tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)
+tts.tts_to_file(text="This is a test.", language="en", file_path=OUTPUT_PATH)
 
 # V1 model
 models = TTS(cs_api_model="V1").list_models()
 # Run TTS with emotion and speed control
 # Emotion control only works with V1 model
 tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH, emotion="Happy", speed=1.5)
 
-# XTTS-multilingual
-models = TTS(cs_api_model="XTTS-multilingual").list_models()
-# Run TTS with emotion and speed control
-# Emotion control only works with V1 model
-tts.tts_to_file(text="Das ist ein Test.", file_path=OUTPUT_PATH, language="de", speed=1.0)
 ```
 
 #### Example text to speech using **Fairseq models in ~1100 languages** 🤯.
@@ -2,15 +2,17 @@
 "tts_models": {
 "multilingual": {
 "multi-dataset": {
-"xtts_v1": {
-"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
+"xtts_v2": {
+"description": "XTTS-v2.0.2 by Coqui with 16 languages.",
 "hf_url": [
-"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/model.pth",
-"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/config.json",
-"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/vocab.json"
+"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
+"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
+"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
+"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
 ],
+"model_hash": "5ce0502bfe3bc88dc8d9312b12a7558c",
 "default_vocoder": null,
-"commit": "e5140314",
+"commit": "480a6cdf7",
 "license": "CPML",
 "contact": "info@coqui.ai",
 "tos_required": true
@@ -18,12 +20,12 @@
 "xtts_v1.1": {
 "description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
 "hf_url": [
-"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
-"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
-"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
-"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
+"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth",
+"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/config.json",
+"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json",
+"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/hash.md5"
 ],
-"model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
+"model_hash": "7c62beaf58d39b729de287330dc254e7b515677416839b649a50e7cf74c3df59",
 "default_vocoder": null,
 "commit": "82910a63",
 "license": "CPML",
@@ -1 +1 @@
-0.19.1
+0.20.6
@@ -60,7 +60,7 @@ class TTS(nn.Module):
 vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
 progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
 cs_api_model (str, optional): Name of the model to use for the Coqui Studio API. Available models are
-"XTTS", "XTTS-multilingual", "V1". You can also use `TTS.cs_api.CS_API" for more control.
+"XTTS", "V1". You can also use `TTS.cs_api.CS_API" for more control.
 Defaults to "XTTS".
 gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
 """
@@ -275,7 +275,7 @@ class TTS(nn.Module):
 speaker_name (str, optional):
 Speaker name from Coqui Studio. Defaults to None.
 language (str): Language of the text. If None, the default language of the speaker is used. Language is only
-supported by `XTTS-multilang` model. Currently supports en, de, es, fr, it, pt, pl. Defaults to "en".
+supported by `XTTS` model.
 emotion (str, optional):
 Emotion of the speaker. One of "Neutral", "Happy", "Sad", "Angry", "Dull". Emotions are only available
 with "V1" model. Defaults to None.
@@ -321,7 +321,7 @@ class TTS(nn.Module):
 Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
 `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
 language (str): Language of the text. If None, the default language of the speaker is used. Language is only
-supported by `XTTS-multilang` model. Currently supports en, de, es, fr, it, pt, pl. Defaults to "en".
+supported by `XTTS` model.
 speaker_wav (str, optional):
 Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
 Defaults to None.
@@ -15,6 +15,7 @@ from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import quantize
 from TTS.utils.generic_utils import count_parameters
 
 use_cuda = torch.cuda.is_available()
@@ -159,7 +160,7 @@ def inference(
 
 
 def extract_spectrograms(
-data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
+data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
 ):
 model.eval()
 export_metadata = []
@@ -196,8 +197,8 @@ def extract_spectrograms(
 _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)
 
 # quantize and save wav
-if quantized_wav:
-wavq = ap.quantize(wav)
+if quantize_bits > 0:
+wavq = quantize(wav, quantize_bits)
 np.save(wavq_path, wavq)
 
 # save TTS mel
@@ -263,7 +264,7 @@ def main(args):  # pylint: disable=redefined-outer-name
 model,
 ap,
 args.output_path,
-quantized_wav=args.quantized,
+quantize_bits=args.quantize_bits,
 save_audio=args.save_audio,
 debug=args.debug,
 metada_name="metada.txt",
@@ -277,7 +278,7 @@ if __name__ == "__main__":
 parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
 parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
 parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
-parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
+parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
 parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
 args = parser.parse_args()
 
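The hunks above replace the boolean `--quantized` flag with a bit-depth argument and call a standalone `quantize(wav, quantize_bits)` helper. As a rough illustration of what such a helper does, here is a sketch of uniform waveform quantization; it is not necessarily the exact code in `TTS.utils.audio.numpy_transforms`:

```python
import numpy as np

def quantize_wav(wav: np.ndarray, bits: int) -> np.ndarray:
    """Map a float waveform in [-1, 1] onto integer levels in [0, 2**bits - 1]."""
    levels = 2**bits - 1
    return np.clip((wav + 1.0) * levels / 2, 0, levels).astype(np.int64)

# Example: the behaviour selected by `--quantize_bits 10`
wav = np.sin(np.linspace(0, 2 * np.pi, 16000)).astype(np.float32)
wavq = quantize_wav(wav, 10)
assert 0 <= wavq.min() and wavq.max() <= 1023
```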
@@ -227,7 +227,7 @@ def main():
 parser.add_argument(
 "--cs_model",
 type=str,
-help="Name of the 🐸Coqui Studio model. Available models are `XTTS`, `XTTS-multilingual`, `V1`.",
+help="Name of the 🐸Coqui Studio model. Available models are `XTTS`, `V1`.",
 )
 parser.add_argument(
 "--emotion",
@@ -238,7 +238,7 @@ def main():
 parser.add_argument(
 "--language",
 type=str,
-help="Language to condition the model with. Only available for 🐸Coqui Studio `XTTS-multilingual` model.",
+help="Language to condition the model with. Only available for 🐸Coqui Studio `XTTS` model.",
 default=None,
 )
 parser.add_argument(
@@ -427,7 +427,9 @@ def main():
 tts_path = model_path
 tts_config_path = config_path
 if "default_vocoder" in model_item:
-args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
+args.vocoder_name = (
+model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
+)
 
 # voice conversion model
 if model_item["model_type"] == "voice_conversion_models":
@@ -8,17 +8,17 @@ import traceback
 
 import torch
 from torch.utils.data import DataLoader
+from trainer.io import copy_model_files, save_best_model, save_checkpoint
 from trainer.torch import NoamLR
 from trainer.trainer_utils import get_optimizer
 
 from TTS.encoder.dataset import EncoderDataset
-from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
+from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
-from TTS.utils.io import copy_model_files
 from TTS.utils.samplers import PerfectBatchSampler
 from TTS.utils.training import check_update
 
@@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
 
 if global_step % c.save_step == 0:
 # save model
-save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
+save_checkpoint(
+c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
+)
 
 end_time = time.time()
 
@@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
 flush=True,
 )
 # save the best checkpoint
-best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
+best_loss = save_best_model(
+eval_loss,
+best_loss,
+c,
+model,
+optimizer,
+None,
+global_step,
+epoch,
+OUT_PATH,
+criterion=criterion.state_dict(),
+)
 model.train()
 
 return best_loss, global_step
@@ -276,7 +289,7 @@ def main(args):  # pylint: disable=redefined-outer-name
 
 if c.loss == "softmaxproto" and c.model != "speaker_encoder":
 c.map_classid_to_classname = map_classid_to_classname
-copy_model_files(c, OUT_PATH)
+copy_model_files(c, OUT_PATH, new_fields={})
 
 if args.restore_path:
 criterion, args.restore_step = model.load_checkpoint(
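The net effect of these hunks is that checkpointing now goes through the external `trainer` package instead of the local helpers removed later in this diff. A condensed sketch of the new call shape, using exactly the argument order shown above (consult the `trainer` package for the authoritative signatures):

```python
from trainer.io import save_best_model, save_checkpoint

def checkpoint_step(c, model, optimizer, criterion, eval_loss, best_loss, out_path, global_step, epoch):
    # Periodic checkpoint: config object first, criterion passed as a state-dict keyword.
    save_checkpoint(
        c, model, optimizer, None, global_step, epoch, out_path, criterion=criterion.state_dict()
    )
    # Best-model bookkeeping: losses first; returns the (possibly updated) best loss.
    return save_best_model(
        eval_loss, best_loss, c, model, optimizer, None,
        global_step, epoch, out_path, criterion=criterion.state_dict(),
    )
```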
@@ -43,7 +43,7 @@ class CS_API:
 Args:
 api_token (str): 🐸Coqui Studio API token. If not provided, it will be read from the environment variable
 `COQUI_STUDIO_TOKEN`.
-model (str): 🐸Coqui Studio model. It can be either `V1`, `XTTS`, or `XTTS-multilang`. Default is `XTTS`.
+model (str): 🐸Coqui Studio model. It can be either `V1`, `XTTS`. Default is `XTTS`.
 
 
 Example listing all available speakers:
@@ -65,7 +65,7 @@ class CS_API:
 
 Example with multi-language model:
 >>> from TTS.api import CS_API
->>> tts = CS_API(model="XTTS-multilang")
+>>> tts = CS_API(model="XTTS")
 >>> wav, sr = api.tts("Hello world", speaker_name=tts.speakers[0].name, language="en")
 """
 
@@ -78,16 +78,11 @@ class CS_API:
 "XTTS": {
 "list_speakers": "https://app.coqui.ai/api/v2/speakers",
 "synthesize": "https://app.coqui.ai/api/v2/samples/xtts/render/",
-"list_voices": "https://app.coqui.ai/api/v2/voices/xtts/",
-},
-"XTTS-multilang": {
-"list_speakers": "https://app.coqui.ai/api/v2/speakers",
-"synthesize": "https://app.coqui.ai/api/v2/samples/multilingual/render/",
-"list_voices": "https://app.coqui.ai/api/v2/voices/xtts/",
+"list_voices": "https://app.coqui.ai/api/v2/voices/xtts",
 },
 }
 
-SUPPORTED_LANGUAGES = ["en", "es", "de", "fr", "it", "pt", "pl"]
+SUPPORTED_LANGUAGES = ["en", "es", "de", "fr", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja"]
 
 def __init__(self, api_token=None, model="XTTS"):
 self.api_token = api_token
@@ -139,7 +134,7 @@ class CS_API:
 self._check_token()
 conn = http.client.HTTPSConnection("app.coqui.ai")
 url = self.MODEL_ENDPOINTS[self.model]["list_speakers"]
-conn.request("GET", f"{url}?per_page=100", headers=self.headers)
+conn.request("GET", f"{url}?page=1&per_page=100", headers=self.headers)
 res = conn.getresponse()
 data = res.read()
 return [Speaker(s) for s in json.loads(data)["result"]]
@@ -148,7 +143,7 @@ class CS_API:
 """List custom voices created by the user."""
 conn = http.client.HTTPSConnection("app.coqui.ai")
 url = self.MODEL_ENDPOINTS[self.model]["list_voices"]
-conn.request("GET", f"{url}", headers=self.headers)
+conn.request("GET", f"{url}?page=1&per_page=100", headers=self.headers)
 res = conn.getresponse()
 data = res.read()
 return [Speaker(s, True) for s in json.loads(data)["result"]]
@@ -197,14 +192,6 @@ class CS_API:
 }
 )
 elif model == "XTTS":
-payload.update(
-{
-"name": speaker.name,
-"text": text,
-"speed": speed,
-}
-)
-elif model == "XTTS-multilang":
 payload.update(
 {
 "name": speaker.name,
@@ -226,13 +213,10 @@ class CS_API:
 assert language is None, "❗ language is not supported for V1 model."
 elif self.model == "XTTS":
 assert emotion is None, f"❗ Emotions are not supported for XTTS model. Use V1 model."
-assert language is None, "❗ Language is not supported for XTTS model. Use XTTS-multilang model."
-elif self.model == "XTTS-multilang":
-assert emotion is None, f"❗ Emotions are not supported for XTTS-multilang model. Use V1 model."
-assert language is not None, "❗ Language is required for XTTS-multilang model."
+assert language is not None, "❗ Language is required for XTTS model."
 assert (
 language in self.SUPPORTED_LANGUAGES
-), f"❗ Language {language} is not yet supported. Use one of: en, es, de, fr, it, pt, pl"
+), f"❗ Language {language} is not yet supported. Check https://docs.coqui.ai/reference/samples_xtts_create."
 return text, speaker_name, speaker_id, emotion, speed, language
 
 def tts(
@@ -255,7 +239,7 @@ class CS_API:
 supported by `V1` model. Defaults to None.
 speed (float): Speed of the speech. 1.0 is normal speed.
 language (str): Language of the text. If None, the default language of the speaker is used. Language is only
-supported by `XTTS-multilang` model. Currently supports en, de, es, fr, it, pt, pl. Defaults to "en".
+supported by `XTTS` model. See https://docs.coqui.ai/reference/samples_xtts_create for supported languages.
 """
 self._check_token()
 self.ping_api()
@@ -305,7 +289,7 @@ class CS_API:
 speed (float): Speed of the speech. 1.0 is normal speed.
 pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
 language (str): Language of the text. If None, the default language of the speaker is used. Language is only
-supported by `XTTS-multilang` model. Currently supports en, de, es, fr, it, pt, pl. Defaults to "en".
+supported by `XTTS` model. Currently supports en, de, es, fr, it, pt, pl. Defaults to "en".
 file_path (str): Path to save the file. If None, a temporary file is created.
 """
 if file_path is None:
@@ -322,21 +306,12 @@ if __name__ == "__main__":
 print(api.speakers)
 print(api.list_speakers_as_tts_models())
 
-ts = time.time()
-wav, sr = api.tts("It took me quite a long time to develop a voice.", speaker_name=api.speakers[0].name)
-print(f" [i] XTTS took {time.time() - ts:.2f}s")
-
-filepath = api.tts_to_file(text="Hello world!", speaker_name=api.speakers[0].name, file_path="output.wav")
-
-api = CS_API(model="XTTS-multilang")
-print(api.speakers)
-
 ts = time.time()
 wav, sr = api.tts(
-"It took me quite a long time to develop a voice.", speaker_name=api.speakers[0].name, language="en"
+"It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name
 )
 print(f" [i] XTTS took {time.time() - ts:.2f}s")
 
 filepath = api.tts_to_file(
-text="Hello world!", speaker_name=api.speakers[0].name, file_path="output.wav", language="en"
+text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav"
 )
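In short, the Studio client drops the separate `XTTS-multilang` entry: `XTTS` is now the multilingual model and `language` becomes mandatory for it. A minimal sketch of the resulting usage, mirroring the `__main__` block above (a valid `COQUI_STUDIO_TOKEN` in the environment is assumed):

```python
from TTS.api import CS_API

api = CS_API(model="XTTS")  # "XTTS-multilang" no longer exists as a separate model
wav, sr = api.tts(
    "It took me quite a long time to develop a voice.",
    speaker_name=api.speakers[0].name,
    language="en",  # required for XTTS after this change
)
```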
@@ -1,15 +1,12 @@
-import datetime
 import glob
 import os
 import random
-import re
 
 import numpy as np
 from scipy import signal
 
 from TTS.encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.encoder.models.resnet import ResNetSpeakerEncoder
-from TTS.utils.io import save_fsspec
 
 
 class AugmentWAV(object):
@@ -118,11 +115,6 @@ class AugmentWAV(object):
 return self.additive_noise(noise_type, audio)
 
 
-def to_camel(text):
-text = text.capitalize()
-return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
-
-
 def setup_encoder_model(config: "Coqpit"):
 if config.model_params["model_name"].lower() == "lstm":
 model = LSTMSpeakerEncoder(
@@ -142,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
 audio_config=config.audio,
 )
 return model
-
-
-def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
-checkpoint_path = "checkpoint_{}.pth".format(current_step)
-checkpoint_path = os.path.join(out_path, checkpoint_path)
-print(" | | > Checkpoint saving : {}".format(checkpoint_path))
-
-new_state_dict = model.state_dict()
-state = {
-"model": new_state_dict,
-"optimizer": optimizer.state_dict() if optimizer is not None else None,
-"criterion": criterion.state_dict(),
-"step": current_step,
-"epoch": epoch,
-"loss": model_loss,
-"date": datetime.date.today().strftime("%B %d, %Y"),
-}
-save_fsspec(state, checkpoint_path)
-
-
-def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
-if model_loss < best_loss:
-new_state_dict = model.state_dict()
-state = {
-"model": new_state_dict,
-"optimizer": optimizer.state_dict(),
-"criterion": criterion.state_dict(),
-"step": current_step,
-"epoch": epoch,
-"loss": model_loss,
-"date": datetime.date.today().strftime("%B %d, %Y"),
-}
-best_loss = model_loss
-bestmodel_path = "best_model.pth"
-bestmodel_path = os.path.join(out_path, bestmodel_path)
-print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-save_fsspec(state, bestmodel_path)
-return best_loss
@@ -1,38 +0,0 @@
-import datetime
-import os
-
-from TTS.utils.io import save_fsspec
-
-
-def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
-checkpoint_path = "checkpoint_{}.pth".format(current_step)
-checkpoint_path = os.path.join(out_path, checkpoint_path)
-print(" | | > Checkpoint saving : {}".format(checkpoint_path))
-
-new_state_dict = model.state_dict()
-state = {
-"model": new_state_dict,
-"optimizer": optimizer.state_dict() if optimizer is not None else None,
-"step": current_step,
-"loss": model_loss,
-"date": datetime.date.today().strftime("%B %d, %Y"),
-}
-save_fsspec(state, checkpoint_path)
-
-
-def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
-if model_loss < best_loss:
-new_state_dict = model.state_dict()
-state = {
-"model": new_state_dict,
-"optimizer": optimizer.state_dict(),
-"step": current_step,
-"loss": model_loss,
-"date": datetime.date.today().strftime("%B %d, %Y"),
-}
-best_loss = model_loss
-bestmodel_path = "best_model.pth"
-bestmodel_path = os.path.join(out_path, bestmodel_path)
-print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-save_fsspec(state, bestmodel_path)
-return best_loss
@@ -3,13 +3,13 @@ from dataclasses import dataclass, field
 
 from coqpit import Coqpit
 from trainer import TrainerArgs, get_last_checkpoint
+from trainer.io import copy_model_files
 from trainer.logging import logger_factory
 from trainer.logging.console_logger import ConsoleLogger
 
 from TTS.config import load_config, register_config
 from TTS.tts.utils.text.characters import parse_symbols
 from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
-from TTS.utils.io import copy_model_files
 
 
 @dataclass
@@ -30,35 +30,32 @@ class XttsConfig(BaseTTSConfig):
 which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative),
 length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
 
-reperation_penalty (float):
+repetition_penalty (float):
 The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`.
 
 top_p (float):
 If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
 Defaults to `0.8`.
 
-cond_free_k (float):
-Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
-As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
-Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k. Defaults to `2.0`.
-
-diffusion_temperature (float):
-Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
-are the "mean" prediction of the diffusion network and will sound bland and smeared.
-Defaults to `1.0`.
 
 num_gpt_outputs (int):
 Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
 As XTTS is a probabilistic model, more samples means a higher probability of creating something "great".
 Defaults to `16`.
 
-decoder_iterations (int):
-Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
-the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
-however. Defaults to `30`.
+gpt_cond_len (int):
+Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.
+gpt_cond_chunk_len (int):
+Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the
+latents are averaged. Chunking improves the stability. It must be <= gpt_cond_len.
+If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`.
 
+max_ref_len (int):
+Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
 
+sound_norm_refs (bool):
+Whether to normalize the conditioning audio. Defaults to `False`.
 
-decoder_sampler (str):
-Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.
 Note:
 Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
 
@@ -74,7 +71,24 @@ class XttsConfig(BaseTTSConfig):
 audio: XttsAudioConfig = field(default_factory=XttsAudioConfig)
 model_dir: str = None
 languages: List[str] = field(
-default_factory=lambda: ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"]
+default_factory=lambda: [
+"en",
+"es",
+"fr",
+"de",
+"it",
+"pt",
+"pl",
+"tr",
+"ru",
+"nl",
+"cs",
+"ar",
+"zh-cn",
+"hu",
+"ko",
+"ja",
+]
 )
 
 # inference params
@@ -83,8 +97,10 @@ class XttsConfig(BaseTTSConfig):
 repetition_penalty: float = 2.0
 top_k: int = 50
 top_p: float = 0.85
-cond_free_k: float = 2.0
-diffusion_temperature: float = 1.0
 num_gpt_outputs: int = 1
-decoder_iterations: int = 30
-decoder_sampler: str = "ddim"
+# cloning
+gpt_cond_len: int = 12
+gpt_cond_chunk_len: int = 4
+max_ref_len: int = 10
+sound_norm_refs: bool = False
 
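For reference, a small sketch of how the new cloning-related fields introduced above might be set; values mirror the defaults from the hunk, and the import path is assumed to be the usual location of `XttsConfig` in this repo:

```python
from TTS.tts.configs.xtts_config import XttsConfig

# The diffusion-era knobs (cond_free_k, diffusion_temperature, decoder_iterations,
# decoder_sampler) are removed; conditioning is controlled by these fields instead.
config = XttsConfig(
    gpt_cond_len=12,        # seconds of reference audio for the autoregressive model
    gpt_cond_chunk_len=4,   # chunk size in seconds; per-chunk latents are averaged
    max_ref_len=10,         # max seconds of reference audio for the decoder
    sound_norm_refs=False,  # optionally normalize the reference audio
)
```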
@@ -280,7 +280,7 @@ def css10(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
 cols = line.split("|")
 wav_file = os.path.join(root_path, cols[0])
 text = cols[1]
-items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
 return items
 
 
@@ -294,7 +294,7 @@ def nancy(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
 utt_id = line.split()[1]
 text = line[line.find('"') + 1 : line.rfind('"') - 1]
 wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
-items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
 return items
 
 
@@ -3,6 +3,7 @@ from typing import Tuple
 import torch
 import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
+from torch.nn.utils import parametrize
 
 from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor
 
@@ -73,7 +74,7 @@ class ConvNorm(nn.Module):
 )
 nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
 if self.use_weight_norm:
-self.conv = nn.utils.weight_norm(self.conv)
+self.conv = nn.utils.parametrizations.weight_norm(self.conv)
 
 def forward(self, signal, mask=None):
 conv_signal = self.conv(signal)
@@ -113,7 +114,7 @@ class ConvLSTMLinear(nn.Module):
 dilation=1,
 w_init_gain="relu",
 )
-conv_layer = nn.utils.weight_norm(conv_layer.conv, name="weight")
+conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
 convolutions.append(conv_layer)
 
 self.convolutions = nn.ModuleList(convolutions)
@@ -567,7 +568,7 @@ class LVCBlock(torch.nn.Module):
 
 self.convt_pre = nn.Sequential(
 nn.LeakyReLU(lReLU_slope),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
 nn.ConvTranspose1d(
 in_channels,
 in_channels,
@@ -584,7 +585,7 @@ class LVCBlock(torch.nn.Module):
 self.conv_blocks.append(
 nn.Sequential(
 nn.LeakyReLU(lReLU_slope),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
 nn.Conv1d(
 in_channels,
 in_channels,
@@ -665,6 +666,6 @@ class LVCBlock(torch.nn.Module):
 
 def remove_weight_norm(self):
 self.kernel_predictor.remove_weight_norm()
-nn.utils.remove_weight_norm(self.convt_pre[1])
+parametrize.remove_parametrizations(self.convt_pre[1], "weight")
 for block in self.conv_blocks:
-nn.utils.remove_weight_norm(block[1])
+parametrize.remove_parametrizations(block[1], "weight")
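The recurring change in these layer files is the switch from PyTorch's legacy `nn.utils.weight_norm` / `nn.utils.remove_weight_norm` pair to the parametrization-based API. A minimal before/after sketch (the layer sizes here are arbitrary):

```python
import torch.nn as nn
from torch.nn.utils import parametrize

conv = nn.Conv1d(16, 16, kernel_size=3, padding=1)

# Legacy API (deprecated in recent PyTorch releases):
#   conv = nn.utils.weight_norm(conv)
#   nn.utils.remove_weight_norm(conv)

# Parametrization-based API used throughout this commit:
conv = nn.utils.parametrizations.weight_norm(conv, name="weight")
# ... train or run inference ...
parametrize.remove_parametrizations(conv, "weight")  # fold the norm back into the weight
```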
@@ -1,4 +1,5 @@
 import torch.nn as nn  # pylint: disable=consider-using-from-import
+from torch.nn.utils import parametrize
 
 
 class KernelPredictor(nn.Module):
@@ -36,7 +37,9 @@ class KernelPredictor(nn.Module):
 kpnet_bias_channels = conv_out_channels * conv_layers  # l_b
 
 self.input_conv = nn.Sequential(
-nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
+nn.utils.parametrizations.weight_norm(
+nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
+),
 getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
 )
 
@@ -46,7 +49,7 @@ class KernelPredictor(nn.Module):
 self.residual_convs.append(
 nn.Sequential(
 nn.Dropout(kpnet_dropout),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
 nn.Conv1d(
 kpnet_hidden_channels,
 kpnet_hidden_channels,
@@ -56,7 +59,7 @@ class KernelPredictor(nn.Module):
 )
 ),
 getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
 nn.Conv1d(
 kpnet_hidden_channels,
 kpnet_hidden_channels,
@@ -68,7 +71,7 @@ class KernelPredictor(nn.Module):
 getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
 )
 )
-self.kernel_conv = nn.utils.weight_norm(
+self.kernel_conv = nn.utils.parametrizations.weight_norm(
 nn.Conv1d(
 kpnet_hidden_channels,
 kpnet_kernel_channels,
@@ -77,7 +80,7 @@ class KernelPredictor(nn.Module):
 bias=True,
 )
 )
-self.bias_conv = nn.utils.weight_norm(
+self.bias_conv = nn.utils.parametrizations.weight_norm(
 nn.Conv1d(
 kpnet_hidden_channels,
 kpnet_bias_channels,
@@ -117,9 +120,9 @@ class KernelPredictor(nn.Module):
 return kernels, bias
 
 def remove_weight_norm(self):
-nn.utils.remove_weight_norm(self.input_conv[0])
-nn.utils.remove_weight_norm(self.kernel_conv)
-nn.utils.remove_weight_norm(self.bias_conv)
+parametrize.remove_parametrizations(self.input_conv[0], "weight")
+parametrize.remove_parametrizations(self.kernel_conv, "weight")
+parametrize.remove_parametrizations(self.bias_conv, "weight")
 for block in self.residual_convs:
-nn.utils.remove_weight_norm(block[1])
-nn.utils.remove_weight_norm(block[3])
+parametrize.remove_parametrizations(block[1], "weight")
+parametrize.remove_parametrizations(block[3], "weight")
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from torch.nn.utils import parametrize
 
 
 @torch.jit.script
@@ -62,7 +63,7 @@ class WN(torch.nn.Module):
 # init conditioning layer
 if c_in_channels > 0:
 cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
-self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
 # intermediate layers
 for i in range(num_layers):
 dilation = dilation_rate**i
@@ -75,7 +76,7 @@ class WN(torch.nn.Module):
 in_layer = torch.nn.Conv1d(
 hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
 )
-in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
 self.in_layers.append(in_layer)
 
 if i < num_layers - 1:
@@ -84,7 +85,7 @@ class WN(torch.nn.Module):
 res_skip_channels = hidden_channels
 
 res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
 self.res_skip_layers.append(res_skip_layer)
 # setup weight norm
 if not weight_norm:
@@ -115,11 +116,11 @@ class WN(torch.nn.Module):
 
 def remove_weight_norm(self):
 if self.c_in_channels != 0:
-torch.nn.utils.remove_weight_norm(self.cond_layer)
+parametrize.remove_parametrizations(self.cond_layer, "weight")
 for l in self.in_layers:
-torch.nn.utils.remove_weight_norm(l)
+parametrize.remove_parametrizations(l, "weight")
 for l in self.res_skip_layers:
-torch.nn.utils.remove_weight_norm(l)
+parametrize.remove_parametrizations(l, "weight")
 
 
 class WNBlocks(nn.Module):
@@ -186,7 +186,7 @@ class CouplingBlock(nn.Module):
 self.sigmoid_scale = sigmoid_scale
 # input layer
 start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
-start = torch.nn.utils.weight_norm(start)
+start = torch.nn.utils.parametrizations.weight_norm(start)
 self.start = start
 # output layer
 # Initializing last layer to 0 makes the affine coupling layers
@@ -13,12 +13,18 @@ import math
 import numpy as np
 import torch
 import torch as th
-from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
 from tqdm import tqdm
 
 from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
 
+try:
+from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
+
 K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
+except ImportError:
+K_DIFFUSION_SAMPLERS = None
+
 
 SAMPLERS = ["dpm++2m", "p", "ddim"]
 
 
@@ -531,6 +537,8 @@ class GaussianDiffusion:
 if self.conditioning_free is not True:
 raise RuntimeError("cond_free must be true")
 with tqdm(total=self.num_timesteps) as pbar:
+if K_DIFFUSION_SAMPLERS is None:
+raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers")
 return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
 else:
 raise RuntimeError("sampler not impl")
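The hunk above turns `k_diffusion` into an optional dependency: the import moves into a try/except at module load, and a missing package only raises once a k-diffusion sampler is actually requested. The same guard pattern in isolation (the package and names here are purely illustrative):

```python
# Optional-dependency guard: import eagerly, fail lazily.
try:
    from some_optional_pkg import fancy_sampler  # hypothetical optional extra

    SAMPLERS = {"fancy": fancy_sampler}
except ImportError:
    SAMPLERS = None

def sample(name, *args, **kwargs):
    if SAMPLERS is None:
        raise ModuleNotFoundError("Install some_optional_pkg to use these samplers")
    return SAMPLERS[name](*args, **kwargs)
```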
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
@ -6,6 +5,7 @@ from typing import Callable, Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import torch.nn.utils.parametrize as parametrize
|
||||||
|
|
||||||
MAX_WAV_VALUE = 32768.0
|
MAX_WAV_VALUE = 32768.0
|
||||||
|
|
||||||
|
@ -44,7 +44,9 @@ class KernelPredictor(torch.nn.Module):
|
||||||
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
||||||
|
|
||||||
self.input_conv = nn.Sequential(
|
self.input_conv = nn.Sequential(
|
||||||
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
nn.utils.parametrizations.weight_norm(
|
||||||
|
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
|
||||||
|
),
|
||||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -54,7 +56,7 @@ class KernelPredictor(torch.nn.Module):
|
||||||
self.residual_convs.append(
|
self.residual_convs.append(
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
nn.Dropout(kpnet_dropout),
|
nn.Dropout(kpnet_dropout),
|
||||||
nn.utils.weight_norm(
|
nn.utils.parametrizations.weight_norm(
|
||||||
nn.Conv1d(
|
nn.Conv1d(
|
||||||
kpnet_hidden_channels,
|
kpnet_hidden_channels,
|
||||||
kpnet_hidden_channels,
|
kpnet_hidden_channels,
|
||||||
|
@ -64,7 +66,7 @@ class KernelPredictor(torch.nn.Module):
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
nn.utils.weight_norm(
|
nn.utils.parametrizations.weight_norm(
|
||||||
nn.Conv1d(
|
nn.Conv1d(
|
||||||
kpnet_hidden_channels,
|
kpnet_hidden_channels,
|
||||||
kpnet_hidden_channels,
|
kpnet_hidden_channels,
|
||||||
|
@ -76,7 +78,7 @@ class KernelPredictor(torch.nn.Module):
|
||||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.kernel_conv = nn.utils.weight_norm(
|
self.kernel_conv = nn.utils.parametrizations.weight_norm(
|
||||||
nn.Conv1d(
|
nn.Conv1d(
|
||||||
kpnet_hidden_channels,
|
kpnet_hidden_channels,
|
||||||
kpnet_kernel_channels,
|
kpnet_kernel_channels,
|
||||||
|
@ -85,7 +87,7 @@ class KernelPredictor(torch.nn.Module):
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.bias_conv = nn.utils.weight_norm(
|
self.bias_conv = nn.utils.parametrizations.weight_norm(
|
||||||
nn.Conv1d(
|
nn.Conv1d(
|
||||||
kpnet_hidden_channels,
|
kpnet_hidden_channels,
|
||||||
kpnet_bias_channels,
|
kpnet_bias_channels,
|
||||||
|
@ -125,12 +127,12 @@ class KernelPredictor(torch.nn.Module):
|
||||||
return kernels, bias
|
return kernels, bias
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.input_conv[0])
|
parametrize.remove_parametrizations(self.input_conv[0], "weight")
|
||||||
nn.utils.remove_weight_norm(self.kernel_conv)
|
parametrize.remove_parametrizations(self.kernel_conv, "weight")
|
||||||
nn.utils.remove_weight_norm(self.bias_conv)
|
parametrize.remove_parametrizations(self.bias_conv, "weight")
|
||||||
for block in self.residual_convs:
|
for block in self.residual_convs:
|
||||||
nn.utils.remove_weight_norm(block[1])
|
parametrize.remove_parametrizations(block[1], "weight")
|
||||||
nn.utils.remove_weight_norm(block[3])
|
parametrize.remove_parametrizations(block[3], "weight")
|
||||||
|
|
||||||
|
|
||||||
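
The kernel-predictor hunks above (and the matching edits to the UnivNet and HiFi-GAN code further down) migrate from PyTorch's deprecated nn.utils.weight_norm / nn.utils.remove_weight_norm pair to the parametrization-based API. A minimal sketch of the equivalence, illustrative only and not part of the diff:

    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    conv = nn.Conv1d(16, 16, kernel_size=3, padding=1)

    # Old, deprecated API:
    #   conv = nn.utils.weight_norm(conv)
    #   nn.utils.remove_weight_norm(conv)

    # Parametrization-based API used throughout this diff:
    conv = nn.utils.parametrizations.weight_norm(conv)   # registers a "weight" parametrization
    parametrize.remove_parametrizations(conv, "weight")  # folds the normalized weight back into conv.weight
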
class LVCBlock(torch.nn.Module):
|
class LVCBlock(torch.nn.Module):
|
||||||
|
@ -169,7 +171,7 @@ class LVCBlock(torch.nn.Module):
|
||||||
|
|
||||||
self.convt_pre = nn.Sequential(
|
self.convt_pre = nn.Sequential(
|
||||||
nn.LeakyReLU(lReLU_slope),
|
nn.LeakyReLU(lReLU_slope),
|
||||||
nn.utils.weight_norm(
|
nn.utils.parametrizations.weight_norm(
|
||||||
nn.ConvTranspose1d(
|
nn.ConvTranspose1d(
|
||||||
in_channels,
|
in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -186,7 +188,7 @@ class LVCBlock(torch.nn.Module):
|
||||||
self.conv_blocks.append(
|
self.conv_blocks.append(
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
nn.LeakyReLU(lReLU_slope),
|
nn.LeakyReLU(lReLU_slope),
|
||||||
nn.utils.weight_norm(
|
nn.utils.parametrizations.weight_norm(
|
||||||
nn.Conv1d(
|
nn.Conv1d(
|
||||||
in_channels,
|
in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -267,9 +269,9 @@ class LVCBlock(torch.nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
self.kernel_predictor.remove_weight_norm()
|
self.kernel_predictor.remove_weight_norm()
|
||||||
nn.utils.remove_weight_norm(self.convt_pre[1])
|
parametrize.remove_parametrizations(self.convt_pre[1], "weight")
|
||||||
for block in self.conv_blocks:
|
for block in self.conv_blocks:
|
||||||
nn.utils.remove_weight_norm(block[1])
|
parametrize.remove_parametrizations(block[1], "weight")
|
||||||
|
|
||||||
|
|
||||||
class UnivNetGenerator(nn.Module):
|
class UnivNetGenerator(nn.Module):
|
||||||
|
@ -314,11 +316,13 @@ class UnivNetGenerator(nn.Module):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
|
self.conv_pre = nn.utils.parametrizations.weight_norm(
|
||||||
|
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
|
||||||
|
)
|
||||||
|
|
||||||
self.conv_post = nn.Sequential(
|
self.conv_post = nn.Sequential(
|
||||||
nn.LeakyReLU(lReLU_slope),
|
nn.LeakyReLU(lReLU_slope),
|
||||||
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
|
nn.utils.parametrizations.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -346,11 +350,11 @@ class UnivNetGenerator(nn.Module):
|
||||||
self.remove_weight_norm()
|
self.remove_weight_norm()
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.conv_pre)
|
parametrize.remove_parametrizations(self.conv_pre, "weight")
|
||||||
|
|
||||||
for layer in self.conv_post:
|
for layer in self.conv_post:
|
||||||
if len(layer.state_dict()) != 0:
|
if len(layer.state_dict()) != 0:
|
||||||
nn.utils.remove_weight_norm(layer)
|
parametrize.remove_parametrizations(layer, "weight")
|
||||||
|
|
||||||
for res_block in self.res_stack:
|
for res_block in self.res_stack:
|
||||||
res_block.remove_weight_norm()
|
res_block.remove_weight_norm()
|
||||||
|
|
|
@ -14,7 +14,7 @@ class DiscriminatorS(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, use_spectral_norm=False):
|
def __init__(self, use_spectral_norm=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
|
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
|
||||||
self.convs = nn.ModuleList(
|
self.convs = nn.ModuleList(
|
||||||
[
|
[
|
||||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||||
|
|
File diff suppressed because it is too large
|
@ -11,6 +11,7 @@ from transformers import GPT2Config
|
||||||
|
|
||||||
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
|
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
|
||||||
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
|
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
|
||||||
|
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
|
||||||
|
|
||||||
|
|
||||||
def null_position_embeddings(range, dim):
|
def null_position_embeddings(range, dim):
|
||||||
|
@ -105,6 +106,8 @@ class GPT(nn.Module):
|
||||||
checkpointing=False,
|
checkpointing=False,
|
||||||
average_conditioning_embeddings=False,
|
average_conditioning_embeddings=False,
|
||||||
label_smoothing=0.0,
|
label_smoothing=0.0,
|
||||||
|
use_perceiver_resampler=False,
|
||||||
|
perceiver_cond_length_compression=256,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -125,6 +128,7 @@ class GPT(nn.Module):
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
|
self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
|
||||||
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
|
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
|
||||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
|
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
|
||||||
self.max_prompt_tokens = max_prompt_tokens
|
self.max_prompt_tokens = max_prompt_tokens
|
||||||
|
@ -132,13 +136,12 @@ class GPT(nn.Module):
|
||||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||||
self.conditioning_dropout = nn.Dropout1d(0.1)
|
self.conditioning_dropout = nn.Dropout1d(0.1)
|
||||||
self.average_conditioning_embeddings = average_conditioning_embeddings
|
self.average_conditioning_embeddings = average_conditioning_embeddings
|
||||||
|
self.use_perceiver_resampler = use_perceiver_resampler
|
||||||
|
self.perceiver_cond_length_compression = perceiver_cond_length_compression
|
||||||
|
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||||
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||||
|
|
||||||
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
|
||||||
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
|
|
||||||
|
|
||||||
(
|
(
|
||||||
self.gpt,
|
self.gpt,
|
||||||
self.mel_pos_embedding,
|
self.mel_pos_embedding,
|
||||||
|
@ -165,9 +168,29 @@ class GPT(nn.Module):
|
||||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||||
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
|
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
|
||||||
|
|
||||||
|
if self.use_perceiver_resampler:
|
||||||
|
# XTTS v2
|
||||||
|
self.conditioning_perceiver = PerceiverResampler(
|
||||||
|
dim=model_dim,
|
||||||
|
depth=2,
|
||||||
|
dim_context=model_dim,
|
||||||
|
num_latents=32,
|
||||||
|
dim_head=64,
|
||||||
|
heads=8,
|
||||||
|
ff_mult=4,
|
||||||
|
use_flash_attn=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# XTTS v1
|
||||||
|
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||||
|
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
|
||||||
|
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
return {
|
return {
|
||||||
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
|
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
|
||||||
|
"conditioning_perceiver": list(self.conditioning_perceiver.parameters())
|
||||||
|
if self.use_perceiver_resampler
|
||||||
|
else None,
|
||||||
"gpt": list(self.gpt.parameters()),
|
"gpt": list(self.gpt.parameters()),
|
||||||
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
||||||
}
|
}
|
||||||
|
@ -250,9 +273,6 @@ class GPT(nn.Module):
|
||||||
if attn_mask_text is not None:
|
if attn_mask_text is not None:
|
||||||
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
|
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
if attn_mask_cond is not None:
|
|
||||||
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
|
|
||||||
else:
|
|
||||||
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||||
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
|
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
|
||||||
|
|
||||||
|
@ -318,7 +338,6 @@ class GPT(nn.Module):
|
||||||
prompt_len = 3
|
prompt_len = 3
|
||||||
prompt_len = prompt_len * 24 # in frames
|
prompt_len = prompt_len * 24 # in frames
|
||||||
if prompt_codes.shape[-1] >= prompt_len:
|
if prompt_codes.shape[-1] >= prompt_len:
|
||||||
new_prompt = []
|
|
||||||
for i in range(prompt_codes.shape[0]):
|
for i in range(prompt_codes.shape[0]):
|
||||||
if lengths[i] < prompt_len:
|
if lengths[i] < prompt_len:
|
||||||
start = 0
|
start = 0
|
||||||
|
@ -340,7 +359,9 @@ class GPT(nn.Module):
|
||||||
if not return_latent:
|
if not return_latent:
|
||||||
if cond_input.ndim == 4:
|
if cond_input.ndim == 4:
|
||||||
cond_input = cond_input.squeeze(1)
|
cond_input = cond_input.squeeze(1)
|
||||||
conds = self.conditioning_encoder(cond_input)
|
conds = self.conditioning_encoder(cond_input) # (b, d, s)
|
||||||
|
if self.use_perceiver_resampler:
|
||||||
|
conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2) # (b, d, 32)
|
||||||
else:
|
else:
|
||||||
# already computed
|
# already computed
|
||||||
conds = cond_input.unsqueeze(1)
|
conds = cond_input.unsqueeze(1)
|
||||||
|
@ -354,6 +375,7 @@ class GPT(nn.Module):
|
||||||
wav_lengths,
|
wav_lengths,
|
||||||
cond_mels=None,
|
cond_mels=None,
|
||||||
cond_idxs=None,
|
cond_idxs=None,
|
||||||
|
cond_lens=None,
|
||||||
cond_latents=None,
|
cond_latents=None,
|
||||||
return_attentions=False,
|
return_attentions=False,
|
||||||
return_latent=False,
|
return_latent=False,
|
||||||
|
@ -379,10 +401,24 @@ class GPT(nn.Module):
|
||||||
max_text_len = text_lengths.max()
|
max_text_len = text_lengths.max()
|
||||||
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
|
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
|
||||||
|
|
||||||
|
if cond_lens is not None:
|
||||||
|
if self.use_perceiver_resampler:
|
||||||
|
cond_lens = cond_lens // self.perceiver_cond_length_compression
|
||||||
|
else:
|
||||||
|
cond_lens = cond_lens // self.code_stride_len
|
||||||
|
|
||||||
if cond_idxs is not None:
|
if cond_idxs is not None:
|
||||||
# recompute cond idxs for mel lengths
|
# recompute cond idxs for mel lengths
|
||||||
for idx, l in enumerate(code_lengths):
|
for idx in range(cond_idxs.size(0)):
|
||||||
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
|
if self.use_perceiver_resampler:
|
||||||
|
cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression
|
||||||
|
else:
|
||||||
|
cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
|
||||||
|
|
||||||
|
# ensure that the cond_mel does not have padding
|
||||||
|
# if cond_lens is not None and cond_idxs is None:
|
||||||
|
# min_cond_len = torch.min(cond_lens)
|
||||||
|
# cond_mels = cond_mels[:, :, :, :min_cond_len]
|
||||||
|
|
||||||
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
|
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
|
||||||
max_mel_len = code_lengths.max()
|
max_mel_len = code_lengths.max()
|
||||||
|
@ -390,15 +426,6 @@ class GPT(nn.Module):
|
||||||
if max_mel_len > audio_codes.shape[-1]:
|
if max_mel_len > audio_codes.shape[-1]:
|
||||||
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
|
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
|
||||||
|
|
||||||
silence = True
|
|
||||||
for idx, l in enumerate(code_lengths):
|
|
||||||
length = l.item()
|
|
||||||
while silence:
|
|
||||||
if audio_codes[idx, length - 1] != 83:
|
|
||||||
break
|
|
||||||
length -= 1
|
|
||||||
code_lengths[idx] = length
|
|
||||||
|
|
||||||
# 💖 Lovely assertions
|
# 💖 Lovely assertions
|
||||||
assert (
|
assert (
|
||||||
max_mel_len <= audio_codes.shape[-1]
|
max_mel_len <= audio_codes.shape[-1]
|
||||||
|
@ -414,7 +441,9 @@ class GPT(nn.Module):
|
||||||
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
|
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
|
||||||
|
|
||||||
# Pad mel codes with stop_audio_token
|
# Pad mel codes with stop_audio_token
|
||||||
audio_codes = self.set_mel_padding(audio_codes, code_lengths)
|
audio_codes = self.set_mel_padding(
|
||||||
|
audio_codes, code_lengths - 3
|
||||||
|
) # -3 to get the real code lengths without considering the start and stop tokens, which have not been added yet
|
||||||
|
|
||||||
# Build input and target tensors
|
# Build input and target tensors
|
||||||
# Prepend start token to inputs and append stop token to targets
|
# Prepend start token to inputs and append stop token to targets
|
||||||
|
@ -450,9 +479,13 @@ class GPT(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
if cond_idxs is not None:
|
if cond_idxs is not None:
|
||||||
|
# use masking approach
|
||||||
for idx, r in enumerate(cond_idxs):
|
for idx, r in enumerate(cond_idxs):
|
||||||
l = r[1] - r[0]
|
l = r[1] - r[0]
|
||||||
attn_mask_cond[idx, l:] = 0.0
|
attn_mask_cond[idx, l:] = 0.0
|
||||||
|
elif cond_lens is not None:
|
||||||
|
for idx, l in enumerate(cond_lens):
|
||||||
|
attn_mask_cond[idx, l:] = 0.0
|
||||||
|
|
||||||
for idx, l in enumerate(text_lengths):
|
for idx, l in enumerate(text_lengths):
|
||||||
attn_mask_text[idx, l + 1 :] = 0.0
|
attn_mask_text[idx, l + 1 :] = 0.0
|
||||||
|
@ -523,7 +556,7 @@ class GPT(nn.Module):
|
||||||
|
|
||||||
def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
|
def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
|
||||||
self.compute_embeddings(cond_latents, text_inputs)
|
self.compute_embeddings(cond_latents, text_inputs)
|
||||||
return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs)
|
return self.generate(cond_latents, text_inputs, **hf_generate_kwargs)
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
self,
|
self,
|
||||||
|
@ -559,7 +592,7 @@ class GPT(nn.Module):
|
||||||
bos_token_id=self.start_audio_token,
|
bos_token_id=self.start_audio_token,
|
||||||
pad_token_id=self.stop_audio_token,
|
pad_token_id=self.stop_audio_token,
|
||||||
eos_token_id=self.stop_audio_token,
|
eos_token_id=self.stop_audio_token,
|
||||||
max_length=self.max_mel_tokens,
|
max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
|
||||||
**hf_generate_kwargs,
|
**hf_generate_kwargs,
|
||||||
)
|
)
|
||||||
if "return_dict_in_generate" in hf_generate_kwargs:
|
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||||
|
@ -572,7 +605,7 @@ class GPT(nn.Module):
|
||||||
bos_token_id=self.start_audio_token,
|
bos_token_id=self.start_audio_token,
|
||||||
pad_token_id=self.stop_audio_token,
|
pad_token_id=self.stop_audio_token,
|
||||||
eos_token_id=self.stop_audio_token,
|
eos_token_id=self.stop_audio_token,
|
||||||
max_length=self.max_mel_tokens,
|
max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
|
||||||
do_stream=True,
|
do_stream=True,
|
||||||
**hf_generate_kwargs,
|
**hf_generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
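
The two max_length changes above switch the generation budget from the fixed max_mel_tokens to max_gen_mel_tokens plus the length of the current prompt, because Hugging Face generate() counts the prompt tokens towards max_length. A rough sketch with assumed numbers, for illustration only:

    # Assumed values, not taken from any config:
    max_mel_tokens = 605
    max_conditioning_inputs = 1

    # Mirrors the new line in __init__ above:
    max_gen_mel_tokens = max_mel_tokens - max_conditioning_inputs - 2  # 602 new audio tokens allowed

    prompt_len = 70                               # e.g. gpt_inputs.shape[-1] at inference time
    max_length = max_gen_mel_tokens + prompt_len  # total sequence length handed to generate()
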
@ -3,7 +3,8 @@ import torchaudio
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import Conv1d, ConvTranspose1d
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
@ -120,9 +121,9 @@ class ResBlock1(torch.nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for l in self.convs1:
|
for l in self.convs1:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
for l in self.convs2:
|
for l in self.convs2:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
|
|
||||||
|
|
||||||
class ResBlock2(torch.nn.Module):
|
class ResBlock2(torch.nn.Module):
|
||||||
|
@ -176,7 +177,7 @@ class ResBlock2(torch.nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for l in self.convs:
|
for l in self.convs:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
|
|
||||||
|
|
||||||
class HifiganGenerator(torch.nn.Module):
|
class HifiganGenerator(torch.nn.Module):
|
||||||
|
@ -251,10 +252,10 @@ class HifiganGenerator(torch.nn.Module):
|
||||||
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
||||||
|
|
||||||
if not conv_pre_weight_norm:
|
if not conv_pre_weight_norm:
|
||||||
remove_weight_norm(self.conv_pre)
|
remove_parametrizations(self.conv_pre, "weight")
|
||||||
|
|
||||||
if not conv_post_weight_norm:
|
if not conv_post_weight_norm:
|
||||||
remove_weight_norm(self.conv_post)
|
remove_parametrizations(self.conv_post, "weight")
|
||||||
|
|
||||||
if self.cond_in_each_up_layer:
|
if self.cond_in_each_up_layer:
|
||||||
self.conds = nn.ModuleList()
|
self.conds = nn.ModuleList()
|
||||||
|
@ -317,11 +318,11 @@ class HifiganGenerator(torch.nn.Module):
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
print("Removing weight norm...")
|
print("Removing weight norm...")
|
||||||
for l in self.ups:
|
for l in self.ups:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
for l in self.resblocks:
|
for l in self.resblocks:
|
||||||
l.remove_weight_norm()
|
l.remove_weight_norm()
|
||||||
remove_weight_norm(self.conv_pre)
|
remove_parametrizations(self.conv_pre, "weight")
|
||||||
remove_weight_norm(self.conv_post)
|
remove_parametrizations(self.conv_post, "weight")
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self, config, checkpoint_path, eval=False, cache=False
|
self, config, checkpoint_path, eval=False, cache=False
|
||||||
|
|
|
@ -0,0 +1,319 @@
|
||||||
|
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
from packaging import version
|
||||||
|
from torch import einsum, nn
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def once(fn):
|
||||||
|
called = False
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(x):
|
||||||
|
nonlocal called
|
||||||
|
if called:
|
||||||
|
return
|
||||||
|
called = True
|
||||||
|
return fn(x)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
print_once = once(print)
|
||||||
|
|
||||||
|
# main class
|
||||||
|
|
||||||
|
|
||||||
|
class Attend(nn.Module):
|
||||||
|
def __init__(self, dropout=0.0, causal=False, use_flash=False):
|
||||||
|
super().__init__()
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attn_dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.causal = causal
|
||||||
|
self.register_buffer("mask", None, persistent=False)
|
||||||
|
|
||||||
|
self.use_flash = use_flash
|
||||||
|
assert not (
|
||||||
|
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
||||||
|
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
||||||
|
|
||||||
|
# determine efficient attention configs for cuda and cpu
|
||||||
|
self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
|
||||||
|
self.cpu_config = self.config(True, True, True)
|
||||||
|
self.cuda_config = None
|
||||||
|
|
||||||
|
if not torch.cuda.is_available() or not use_flash:
|
||||||
|
return
|
||||||
|
|
||||||
|
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
||||||
|
|
||||||
|
if device_properties.major == 8 and device_properties.minor == 0:
|
||||||
|
print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
|
||||||
|
self.cuda_config = self.config(True, False, False)
|
||||||
|
else:
|
||||||
|
print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
|
||||||
|
self.cuda_config = self.config(False, True, True)
|
||||||
|
|
||||||
|
def get_mask(self, n, device):
|
||||||
|
if exists(self.mask) and self.mask.shape[-1] >= n:
|
||||||
|
return self.mask[:n, :n]
|
||||||
|
|
||||||
|
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
|
||||||
|
self.register_buffer("mask", mask, persistent=False)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def flash_attn(self, q, k, v, mask=None):
|
||||||
|
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
|
||||||
|
|
||||||
|
# Recommended for multi-query single-key-value attention by Tri Dao
|
||||||
|
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
||||||
|
|
||||||
|
if k.ndim == 3:
|
||||||
|
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
|
||||||
|
|
||||||
|
if v.ndim == 3:
|
||||||
|
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
|
||||||
|
|
||||||
|
# Check if mask exists and expand to compatible shape
|
||||||
|
# The mask is B L, so it would have to be expanded to B H N L
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = rearrange(mask, "b j -> b 1 1 j")
|
||||||
|
mask = mask.expand(-1, heads, q_len, -1)
|
||||||
|
|
||||||
|
# Check if there is a compatible device for flash attention
|
||||||
|
|
||||||
|
config = self.cuda_config if is_cuda else self.cpu_config
|
||||||
|
|
||||||
|
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
||||||
|
|
||||||
|
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
||||||
|
out = F.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, q, k, v, mask=None):
|
||||||
|
"""
|
||||||
|
einstein notation
|
||||||
|
b - batch
|
||||||
|
h - heads
|
||||||
|
n, i, j - sequence length (base sequence length, source, target)
|
||||||
|
d - feature dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
n, device = q.shape[-2], q.device
|
||||||
|
|
||||||
|
scale = q.shape[-1] ** -0.5
|
||||||
|
|
||||||
|
if self.use_flash:
|
||||||
|
return self.flash_attn(q, k, v, mask=mask)
|
||||||
|
|
||||||
|
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
|
||||||
|
|
||||||
|
# similarity
|
||||||
|
|
||||||
|
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
|
||||||
|
|
||||||
|
# key padding mask
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = rearrange(mask, "b j -> b 1 1 j")
|
||||||
|
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
||||||
|
|
||||||
|
# causal mask
|
||||||
|
|
||||||
|
if self.causal:
|
||||||
|
causal_mask = self.get_mask(n, device)
|
||||||
|
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
attn = self.attn_dropout(attn)
|
||||||
|
|
||||||
|
# aggregate values
|
||||||
|
|
||||||
|
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
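
The Attend module above keeps two equivalent attention paths: the manual einsum implementation and, when use_flash=True on PyTorch >= 2.0, F.scaled_dot_product_attention. A small sanity-check sketch (illustrative only; the shapes are assumptions):

    import torch

    attend_ref = Attend(use_flash=False)
    attend_sdpa = Attend(use_flash=True)  # requires PyTorch >= 2.0

    q = torch.randn(1, 8, 32, 64)  # (batch, heads, query_len, dim_head)
    k = torch.randn(1, 8, 32, 64)
    v = torch.randn(1, 8, 32, 64)

    out_ref = attend_ref(q, k, v)
    out_sdpa = attend_sdpa(q, k, v)
    print(torch.allclose(out_ref, out_sdpa, atol=1e-5))  # expected to print True
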
|
def Sequential(*mods):
|
||||||
|
return nn.Sequential(*filter(exists, mods))
|
||||||
|
|
||||||
|
|
||||||
|
def exists(x):
|
||||||
|
return x is not None
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
if exists(val):
|
||||||
|
return val
|
||||||
|
return d() if callable(d) else d
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim, scale=True, dim_cond=None):
|
||||||
|
super().__init__()
|
||||||
|
self.cond = exists(dim_cond)
|
||||||
|
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
|
||||||
|
|
||||||
|
self.scale = dim**0.5
|
||||||
|
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
|
||||||
|
|
||||||
|
def forward(self, x, cond=None):
|
||||||
|
gamma = default(self.gamma, 1)
|
||||||
|
out = F.normalize(x, dim=-1) * self.scale * gamma
|
||||||
|
|
||||||
|
if not self.cond:
|
||||||
|
return out
|
||||||
|
|
||||||
|
assert exists(cond)
|
||||||
|
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
|
||||||
|
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
|
||||||
|
return out * gamma + beta
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv1d(nn.Conv1d):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
(kernel_size,) = self.kernel_size
|
||||||
|
(dilation,) = self.dilation
|
||||||
|
(stride,) = self.stride
|
||||||
|
|
||||||
|
assert stride == 1
|
||||||
|
self.causal_padding = dilation * (kernel_size - 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||||
|
return super().forward(causal_padded_x)
|
||||||
|
|
||||||
|
|
||||||
|
class GEGLU(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
x, gate = x.chunk(2, dim=-1)
|
||||||
|
return F.gelu(gate) * x
|
||||||
|
|
||||||
|
|
||||||
|
def FeedForward(dim, mult=4, causal_conv=False):
|
||||||
|
dim_inner = int(dim * mult * 2 / 3)
|
||||||
|
|
||||||
|
conv = None
|
||||||
|
if causal_conv:
|
||||||
|
conv = nn.Sequential(
|
||||||
|
Rearrange("b n d -> b d n"),
|
||||||
|
CausalConv1d(dim_inner, dim_inner, 3),
|
||||||
|
Rearrange("b d n -> b n d"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim))
|
||||||
|
|
||||||
|
|
||||||
|
class PerceiverResampler(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth=2,
|
||||||
|
dim_context=None,
|
||||||
|
num_latents=32,
|
||||||
|
dim_head=64,
|
||||||
|
heads=8,
|
||||||
|
ff_mult=4,
|
||||||
|
use_flash_attn=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
dim_context = default(dim_context, dim)
|
||||||
|
|
||||||
|
self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
|
||||||
|
|
||||||
|
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
||||||
|
nn.init.normal_(self.latents, std=0.02)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
for _ in range(depth):
|
||||||
|
self.layers.append(
|
||||||
|
nn.ModuleList(
|
||||||
|
[
|
||||||
|
Attention(
|
||||||
|
dim=dim,
|
||||||
|
dim_head=dim_head,
|
||||||
|
heads=heads,
|
||||||
|
use_flash=use_flash_attn,
|
||||||
|
cross_attn_include_queries=True,
|
||||||
|
),
|
||||||
|
FeedForward(dim=dim, mult=ff_mult),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(dim)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
batch = x.shape[0]
|
||||||
|
|
||||||
|
x = self.proj_context(x)
|
||||||
|
|
||||||
|
latents = repeat(self.latents, "n d -> b n d", b=batch)
|
||||||
|
|
||||||
|
for attn, ff in self.layers:
|
||||||
|
latents = attn(latents, x, mask=mask) + latents
|
||||||
|
latents = ff(latents) + latents
|
||||||
|
|
||||||
|
return self.norm(latents)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
dim_context=None,
|
||||||
|
causal=False,
|
||||||
|
dim_head=64,
|
||||||
|
heads=8,
|
||||||
|
dropout=0.0,
|
||||||
|
use_flash=False,
|
||||||
|
cross_attn_include_queries=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.heads = heads
|
||||||
|
self.cross_attn_include_queries = cross_attn_include_queries
|
||||||
|
|
||||||
|
dim_inner = dim_head * heads
|
||||||
|
dim_context = default(dim_context, dim)
|
||||||
|
|
||||||
|
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
|
||||||
|
self.to_q = nn.Linear(dim, dim_inner, bias=False)
|
||||||
|
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
|
||||||
|
self.to_out = nn.Linear(dim_inner, dim, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x, context=None, mask=None):
|
||||||
|
h, has_context = self.heads, exists(context)
|
||||||
|
|
||||||
|
context = default(context, x)
|
||||||
|
|
||||||
|
if has_context and self.cross_attn_include_queries:
|
||||||
|
context = torch.cat((x, context), dim=-2)
|
||||||
|
|
||||||
|
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
|
||||||
|
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
||||||
|
|
||||||
|
out = self.attend(q, k, v, mask=mask)
|
||||||
|
|
||||||
|
out = rearrange(out, "b h n d -> b n (h d)")
|
||||||
|
return self.to_out(out)
|
|
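
The hunk earlier in this diff where conds = self.conditioning_encoder(cond_input) is followed by the perceiver call shows how this resampler is consumed: a variable-length conditioning sequence is compressed to 32 latents. A shape sketch with assumed batch size, width and sequence length, illustrative only:

    import torch

    model_dim = 1024
    resampler = PerceiverResampler(
        dim=model_dim, depth=2, dim_context=model_dim,
        num_latents=32, dim_head=64, heads=8, ff_mult=4, use_flash_attn=False,
    )

    conds = torch.randn(2, model_dim, 180)                       # (b, d, s) from the conditioning encoder
    latents = resampler(conds.permute(0, 2, 1)).transpose(1, 2)  # (b, d, 32), fixed length
    print(latents.shape)                                         # torch.Size([2, 1024, 32])
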
@ -1,14 +1,73 @@
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import json
|
import textwrap
|
||||||
|
from functools import cached_property
|
||||||
import torch
|
|
||||||
from tokenizers import Tokenizer
|
|
||||||
|
|
||||||
import pypinyin
|
import pypinyin
|
||||||
|
import torch
|
||||||
|
from hangul_romanize import Transliter
|
||||||
|
from hangul_romanize.rule import academic
|
||||||
from num2words import num2words
|
from num2words import num2words
|
||||||
|
from spacy.lang.ar import Arabic
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.lang.es import Spanish
|
||||||
|
from spacy.lang.ja import Japanese
|
||||||
|
from spacy.lang.zh import Chinese
|
||||||
|
from tokenizers import Tokenizer
|
||||||
|
|
||||||
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
||||||
|
|
||||||
|
|
||||||
|
def get_spacy_lang(lang):
|
||||||
|
if lang == "zh":
|
||||||
|
return Chinese()
|
||||||
|
elif lang == "ja":
|
||||||
|
return Japanese()
|
||||||
|
elif lang == "ar":
|
||||||
|
return Arabic()
|
||||||
|
elif lang == "es":
|
||||||
|
return Spanish()
|
||||||
|
else:
|
||||||
|
# For most languages, English does the job
|
||||||
|
return English()
|
||||||
|
|
||||||
|
|
||||||
|
def split_sentence(text, lang, text_split_length=250):
|
||||||
|
"""Preprocess the input text"""
|
||||||
|
text_splits = []
|
||||||
|
if text_split_length is not None and len(text) >= text_split_length:
|
||||||
|
text_splits.append("")
|
||||||
|
nlp = get_spacy_lang(lang)
|
||||||
|
nlp.add_pipe("sentencizer")
|
||||||
|
doc = nlp(text)
|
||||||
|
for sentence in doc.sents:
|
||||||
|
if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
|
||||||
|
# if the last sentence + the current sentence is less than the text_split_length
|
||||||
|
# then add the current sentence to the last sentence
|
||||||
|
text_splits[-1] += " " + str(sentence)
|
||||||
|
text_splits[-1] = text_splits[-1].lstrip()
|
||||||
|
elif len(str(sentence)) > text_split_length:
|
||||||
|
# if the current sentence is greater than the text_split_length
|
||||||
|
for line in textwrap.wrap(
|
||||||
|
str(sentence),
|
||||||
|
width=text_split_length,
|
||||||
|
drop_whitespace=True,
|
||||||
|
break_on_hyphens=False,
|
||||||
|
tabsize=1,
|
||||||
|
):
|
||||||
|
text_splits.append(str(line))
|
||||||
|
else:
|
||||||
|
text_splits.append(str(sentence))
|
||||||
|
|
||||||
|
if len(text_splits) > 1:
|
||||||
|
if text_splits[0] == "":
|
||||||
|
del text_splits[0]
|
||||||
|
else:
|
||||||
|
text_splits = [text.lstrip()]
|
||||||
|
|
||||||
|
return text_splits
|
||||||
|
|
||||||
|
|
||||||
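
A quick usage sketch of the sentence splitter added above; the text and limit are arbitrary and only meant as an illustration:

    long_text = "First sentence. Second sentence. " * 20  # roughly 660 characters
    chunks = split_sentence(long_text, lang="en", text_split_length=250)
    for chunk in chunks:
        print(len(chunk), chunk[:40])
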
_whitespace_re = re.compile(r"\s+")
|
_whitespace_re = re.compile(r"\s+")
|
||||||
|
|
||||||
# List of (regular expression, replacement) pairs for abbreviations:
|
# List of (regular expression, replacement) pairs for abbreviations:
|
||||||
|
@ -112,7 +171,7 @@ _abbreviations = {
|
||||||
# There are not many common abbreviations in Arabic as in English.
|
# There are not many common abbreviations in Arabic as in English.
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"zh-cn": [
|
"zh": [
|
||||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
|
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
|
||||||
|
@ -155,15 +214,32 @@ _abbreviations = {
|
||||||
# Add other Turkish abbreviations here if needed.
|
# Add other Turkish abbreviations here if needed.
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
|
"hu": [
|
||||||
|
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||||
|
for x in [
|
||||||
|
("dr", "doktor"), # doctor
|
||||||
|
("b", "bácsi"), # Mr.
|
||||||
|
("nőv", "nővér"), # nurse
|
||||||
|
# Add other Hungarian abbreviations here if needed.
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"ko": [
|
||||||
|
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||||
|
for x in [
|
||||||
|
# Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
|
||||||
|
]
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
def expand_abbreviations_multilingual(text, lang='en'):
|
|
||||||
|
def expand_abbreviations_multilingual(text, lang="en"):
|
||||||
for regex, replacement in _abbreviations[lang]:
|
for regex, replacement in _abbreviations[lang]:
|
||||||
text = re.sub(regex, replacement, text)
|
text = re.sub(regex, replacement, text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
_symbols_multilingual = {
|
_symbols_multilingual = {
|
||||||
'en': [
|
"en": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " and "),
|
("&", " and "),
|
||||||
|
@ -172,10 +248,10 @@ _symbols_multilingual = {
|
||||||
("#", " hash "),
|
("#", " hash "),
|
||||||
("$", " dollar "),
|
("$", " dollar "),
|
||||||
("£", " pound "),
|
("£", " pound "),
|
||||||
("°", " degree ")
|
("°", " degree "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'es': [
|
"es": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " y "),
|
("&", " y "),
|
||||||
|
@ -184,10 +260,10 @@ _symbols_multilingual = {
|
||||||
("#", " numeral "),
|
("#", " numeral "),
|
||||||
("$", " dolar "),
|
("$", " dolar "),
|
||||||
("£", " libra "),
|
("£", " libra "),
|
||||||
("°", " grados ")
|
("°", " grados "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'fr': [
|
"fr": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " et "),
|
("&", " et "),
|
||||||
|
@ -196,10 +272,10 @@ _symbols_multilingual = {
|
||||||
("#", " dièse "),
|
("#", " dièse "),
|
||||||
("$", " dollar "),
|
("$", " dollar "),
|
||||||
("£", " livre "),
|
("£", " livre "),
|
||||||
("°", " degrés ")
|
("°", " degrés "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'de': [
|
"de": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " und "),
|
("&", " und "),
|
||||||
|
@ -208,10 +284,10 @@ _symbols_multilingual = {
|
||||||
("#", " raute "),
|
("#", " raute "),
|
||||||
("$", " dollar "),
|
("$", " dollar "),
|
||||||
("£", " pfund "),
|
("£", " pfund "),
|
||||||
("°", " grad ")
|
("°", " grad "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'pt': [
|
"pt": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " e "),
|
("&", " e "),
|
||||||
|
@ -220,10 +296,10 @@ _symbols_multilingual = {
|
||||||
("#", " cardinal "),
|
("#", " cardinal "),
|
||||||
("$", " dólar "),
|
("$", " dólar "),
|
||||||
("£", " libra "),
|
("£", " libra "),
|
||||||
("°", " graus ")
|
("°", " graus "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'it': [
|
"it": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " e "),
|
("&", " e "),
|
||||||
|
@ -232,10 +308,10 @@ _symbols_multilingual = {
|
||||||
("#", " cancelletto "),
|
("#", " cancelletto "),
|
||||||
("$", " dollaro "),
|
("$", " dollaro "),
|
||||||
("£", " sterlina "),
|
("£", " sterlina "),
|
||||||
("°", " gradi ")
|
("°", " gradi "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'pl': [
|
"pl": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " i "),
|
("&", " i "),
|
||||||
|
@ -244,7 +320,7 @@ _symbols_multilingual = {
|
||||||
("#", " krzyżyk "),
|
("#", " krzyżyk "),
|
||||||
("$", " dolar "),
|
("$", " dolar "),
|
||||||
("£", " funt "),
|
("£", " funt "),
|
||||||
("°", " stopnie ")
|
("°", " stopnie "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"ar": [
|
"ar": [
|
||||||
|
@ -257,10 +333,10 @@ _symbols_multilingual = {
|
||||||
("#", " رقم "),
|
("#", " رقم "),
|
||||||
("$", " دولار "),
|
("$", " دولار "),
|
||||||
("£", " جنيه "),
|
("£", " جنيه "),
|
||||||
("°", " درجة ")
|
("°", " درجة "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"zh-cn": [
|
"zh": [
|
||||||
# Chinese
|
# Chinese
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
|
@ -270,7 +346,7 @@ _symbols_multilingual = {
|
||||||
("#", " 号 "),
|
("#", " 号 "),
|
||||||
("$", " 美元 "),
|
("$", " 美元 "),
|
||||||
("£", " 英镑 "),
|
("£", " 英镑 "),
|
||||||
("°", " 度 ")
|
("°", " 度 "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"cs": [
|
"cs": [
|
||||||
|
@ -283,7 +359,7 @@ _symbols_multilingual = {
|
||||||
("#", " křížek "),
|
("#", " křížek "),
|
||||||
("$", " dolar "),
|
("$", " dolar "),
|
||||||
("£", " libra "),
|
("£", " libra "),
|
||||||
("°", " stupně ")
|
("°", " stupně "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"ru": [
|
"ru": [
|
||||||
|
@ -296,7 +372,7 @@ _symbols_multilingual = {
|
||||||
("#", " номер "),
|
("#", " номер "),
|
||||||
("$", " доллар "),
|
("$", " доллар "),
|
||||||
("£", " фунт "),
|
("£", " фунт "),
|
||||||
("°", " градус ")
|
("°", " градус "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"nl": [
|
"nl": [
|
||||||
|
@ -309,7 +385,7 @@ _symbols_multilingual = {
|
||||||
("#", " hekje "),
|
("#", " hekje "),
|
||||||
("$", " dollar "),
|
("$", " dollar "),
|
||||||
("£", " pond "),
|
("£", " pond "),
|
||||||
("°", " graden ")
|
("°", " graden "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"tr": [
|
"tr": [
|
||||||
|
@ -321,15 +397,41 @@ _symbols_multilingual = {
|
||||||
("#", " diyez "),
|
("#", " diyez "),
|
||||||
("$", " dolar "),
|
("$", " dolar "),
|
||||||
("£", " sterlin "),
|
("£", " sterlin "),
|
||||||
("°", " derece ")
|
("°", " derece "),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"hu": [
|
||||||
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
|
for x in [
|
||||||
|
("&", " és "),
|
||||||
|
("@", " kukac "),
|
||||||
|
("%", " százalék "),
|
||||||
|
("#", " kettőskereszt "),
|
||||||
|
("$", " dollár "),
|
||||||
|
("£", " font "),
|
||||||
|
("°", " fok "),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"ko": [
|
||||||
|
# Korean
|
||||||
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
|
for x in [
|
||||||
|
("&", " 그리고 "),
|
||||||
|
("@", " 에 "),
|
||||||
|
("%", " 퍼센트 "),
|
||||||
|
("#", " 번호 "),
|
||||||
|
("$", " 달러 "),
|
||||||
|
("£", " 파운드 "),
|
||||||
|
("°", " 도 "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
def expand_symbols_multilingual(text, lang='en'):
|
|
||||||
|
def expand_symbols_multilingual(text, lang="en"):
|
||||||
for regex, replacement in _symbols_multilingual[lang]:
|
for regex, replacement in _symbols_multilingual[lang]:
|
||||||
text = re.sub(regex, replacement, text)
|
text = re.sub(regex, replacement, text)
|
||||||
text = text.replace(' ', ' ') # Ensure there are no double spaces
|
text = text.replace(" ", " ") # Ensure there are no double spaces
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
@ -346,37 +448,43 @@ _ordinal_re = {
|
||||||
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
|
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
|
||||||
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
|
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
|
||||||
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
|
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
|
||||||
|
"hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
|
||||||
|
"ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
|
||||||
}
|
}
|
||||||
_number_re = re.compile(r"[0-9]+")
|
_number_re = re.compile(r"[0-9]+")
|
||||||
_currency_re = {
|
_currency_re = {
|
||||||
'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
||||||
'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
||||||
'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))")
|
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
|
||||||
}
|
}
|
||||||
|
|
||||||
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
||||||
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
||||||
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
||||||
|
|
||||||
|
|
||||||
def _remove_commas(m):
|
def _remove_commas(m):
|
||||||
text = m.group(0)
|
text = m.group(0)
|
||||||
if "," in text:
|
if "," in text:
|
||||||
text = text.replace(",", "")
|
text = text.replace(",", "")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def _remove_dots(m):
|
def _remove_dots(m):
|
||||||
text = m.group(0)
|
text = m.group(0)
|
||||||
if "." in text:
|
if "." in text:
|
||||||
text = text.replace(".", "")
|
text = text.replace(".", "")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def _expand_decimal_point(m, lang='en'):
|
|
||||||
|
def _expand_decimal_point(m, lang="en"):
|
||||||
amount = m.group(1).replace(",", ".")
|
amount = m.group(1).replace(",", ".")
|
||||||
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
|
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
|
||||||
|
|
||||||
def _expand_currency(m, lang='en', currency='USD'):
|
|
||||||
amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", "."))))
|
def _expand_currency(m, lang="en", currency="USD"):
|
||||||
full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz")
|
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
|
||||||
|
full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
|
||||||
|
|
||||||
and_equivalents = {
|
and_equivalents = {
|
||||||
"en": ", ",
|
"en": ", ",
|
||||||
|
@ -391,6 +499,8 @@ def _expand_currency(m, lang='en', currency='USD'):
|
||||||
"nl": ", ",
|
"nl": ", ",
|
||||||
"ar": ", ",
|
"ar": ", ",
|
||||||
"tr": ", ",
|
"tr": ", ",
|
||||||
|
"hu": ", ",
|
||||||
|
"ko": ", ",
|
||||||
}
|
}
|
||||||
|
|
||||||
if amount.is_integer():
|
if amount.is_integer():
|
||||||
|
@ -400,14 +510,17 @@ def _expand_currency(m, lang='en', currency='USD'):
|
||||||
|
|
||||||
return full_amount
|
return full_amount
|
||||||
|
|
||||||
def _expand_ordinal(m, lang='en'):
|
|
||||||
|
def _expand_ordinal(m, lang="en"):
|
||||||
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
|
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
|
||||||
|
|
||||||
def _expand_number(m, lang='en'):
|
|
||||||
|
def _expand_number(m, lang="en"):
|
||||||
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
|
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
|
||||||
|
|
||||||
def expand_numbers_multilingual(text, lang='en'):
|
|
||||||
if lang == "zh-cn":
|
def expand_numbers_multilingual(text, lang="en"):
|
||||||
|
if lang == "zh":
|
||||||
text = zh_num2words()(text)
|
text = zh_num2words()(text)
|
||||||
else:
|
else:
|
||||||
if lang in ["en", "ru"]:
|
if lang in ["en", "ru"]:
|
||||||
|
@ -415,9 +528,9 @@ def expand_numbers_multilingual(text, lang='en'):
|
||||||
else:
|
else:
|
||||||
text = re.sub(_dot_number_re, _remove_dots, text)
|
text = re.sub(_dot_number_re, _remove_dots, text)
|
||||||
try:
|
try:
|
||||||
text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text)
|
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
|
||||||
text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text)
|
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
|
||||||
text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text)
|
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
if lang != "tr":
|
if lang != "tr":
|
||||||
|
@ -426,14 +539,17 @@ def expand_numbers_multilingual(text, lang='en'):
|
||||||
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def lowercase(text):
|
def lowercase(text):
|
||||||
return text.lower()
|
return text.lower()
|
||||||
|
|
||||||
|
|
||||||
def collapse_whitespace(text):
|
def collapse_whitespace(text):
|
||||||
return re.sub(_whitespace_re, " ", text)
|
return re.sub(_whitespace_re, " ", text)
|
||||||
|
|
||||||
|
|
||||||
def multilingual_cleaners(text, lang):
|
def multilingual_cleaners(text, lang):
|
||||||
text = text.replace('"', '')
|
text = text.replace('"', "")
|
||||||
if lang == "tr":
|
if lang == "tr":
|
||||||
text = text.replace("İ", "i")
|
text = text.replace("İ", "i")
|
||||||
text = text.replace("Ö", "ö")
|
text = text.replace("Ö", "ö")
|
||||||
|
@ -445,55 +561,90 @@ def multilingual_cleaners(text, lang):
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def basic_cleaners(text):
|
def basic_cleaners(text):
|
||||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||||
text = lowercase(text)
|
text = lowercase(text)
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def chinese_transliterate(text):
|
def chinese_transliterate(text):
|
||||||
return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)])
|
return "".join(
|
||||||
|
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def japanese_cleaners(text, katsu):
|
def japanese_cleaners(text, katsu):
|
||||||
text = katsu.romaji(text)
|
text = katsu.romaji(text)
|
||||||
text = lowercase(text)
|
text = lowercase(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def korean_transliterate(text):
|
||||||
|
r = Transliter(academic)
|
||||||
|
return r.translit(text)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json")
|
||||||
|
|
||||||
|
|
||||||
class VoiceBpeTokenizer:
|
class VoiceBpeTokenizer:
|
||||||
def __init__(self, vocab_file=None, preprocess=None):
|
def __init__(self, vocab_file=None):
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.katsu = None
|
|
||||||
|
|
||||||
if vocab_file is not None:
|
if vocab_file is not None:
|
||||||
with open(vocab_file, "r", encoding="utf-8") as f:
|
|
||||||
vocab = json.load(f)
|
|
||||||
|
|
||||||
self.language = vocab["model"]["language"] if "language" in vocab["model"] else None
|
|
||||||
|
|
||||||
if preprocess is None:
|
|
||||||
self.preprocess = "pre_tokenizer" in vocab and vocab["pre_tokenizer"]
|
|
||||||
else:
|
|
||||||
self.preprocess = preprocess
|
|
||||||
|
|
||||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||||
|
self.char_limits = {
|
||||||
|
"en": 250,
|
||||||
|
"de": 253,
|
||||||
|
"fr": 273,
|
||||||
|
"es": 239,
|
||||||
|
"it": 213,
|
||||||
|
"pt": 203,
|
||||||
|
"pl": 224,
|
||||||
|
"zh": 82,
|
||||||
|
"ar": 166,
|
||||||
|
"cs": 186,
|
||||||
|
"ru": 182,
|
||||||
|
"nl": 251,
|
||||||
|
"tr": 226,
|
||||||
|
"ja": 71,
|
||||||
|
"hu": 224,
|
||||||
|
"ko": 95,
|
||||||
|
}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def katsu(self):
|
||||||
|
import cutlet
|
||||||
|
|
||||||
|
return cutlet.Cutlet()
|
||||||
|
|
||||||
|
def check_input_length(self, txt, lang):
|
||||||
|
lang = lang.split("-")[0] # remove the region
|
||||||
|
limit = self.char_limits.get(lang, 250)
|
||||||
|
if len(txt) > limit:
|
||||||
|
print(
|
||||||
|
f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
|
||||||
|
)
|
||||||
|
|
||||||
def preprocess_text(self, txt, lang):
|
def preprocess_text(self, txt, lang):
|
||||||
if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]:
|
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
|
||||||
txt = multilingual_cleaners(txt, lang)
|
txt = multilingual_cleaners(txt, lang)
|
||||||
if lang == "zh-cn":
|
if lang == "zh":
|
||||||
txt = chinese_transliterate(txt)
|
txt = chinese_transliterate(txt)
|
||||||
|
if lang == "ko":
|
||||||
|
txt = korean_transliterate(txt)
|
||||||
elif lang == "ja":
|
elif lang == "ja":
|
||||||
if self.katsu is None:
|
|
||||||
import cutlet
|
|
||||||
self.katsu = cutlet.Cutlet()
|
|
||||||
txt = japanese_cleaners(txt, self.katsu)
|
txt = japanese_cleaners(txt, self.katsu)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError(f"Language '{lang}' is not supported.")
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
def encode(self, txt, lang):
|
def encode(self, txt, lang):
|
||||||
if self.preprocess:
|
lang = lang.split("-")[0] # remove the region
|
||||||
|
self.check_input_length(txt, lang)
|
||||||
txt = self.preprocess_text(txt, lang)
|
txt = self.preprocess_text(txt, lang)
|
||||||
|
lang = "zh-cn" if lang == "zh" else lang
|
||||||
txt = f"[{lang}]{txt}"
|
txt = f"[{lang}]{txt}"
|
||||||
txt = txt.replace(" ", "[SPACE]")
|
txt = txt.replace(" ", "[SPACE]")
|
||||||
return self.tokenizer.encode(txt).ids
|
return self.tokenizer.encode(txt).ids
|
||||||
|
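
Putting the tokenizer changes together, a usage sketch (illustrative only; it assumes the default tokenizer.json shipped with the model is present at DEFAULT_VOCAB_FILE):

    tok = VoiceBpeTokenizer(vocab_file=DEFAULT_VOCAB_FILE)

    # "zh-cn" is reduced to "zh" for the length check and text cleaning,
    # then mapped back to the "[zh-cn]" language token before BPE encoding.
    ids = tok.encode("这是一个测试。", lang="zh-cn")
    print(len(ids))
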
@@ -512,3 +663,178 @@ class VoiceBpeTokenizer:
     def get_number_tokens(self):
         return max(self.tokenizer.get_vocab().values()) + 1
+
+
+def test_expand_numbers_multilingual():
+    test_cases = [
+        # English
+        ("In 12.5 seconds.", "In twelve point five seconds.", "en"),
+        ("There were 50 soldiers.", "There were fifty soldiers.", "en"),
+        ("This is a 1st test", "This is a first test", "en"),
+        ("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
+        ("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
+        ("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
+        ("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
+        # French
+        ("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
+        ("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
+        ("Ceci est un 1er test", "Ceci est un premier test", "fr"),
+        ("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
+        ("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
+        ("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
+        ("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
+        # German
+        ("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
+        ("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
+        ("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"),  # Issue with gender
+        ("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
+        ("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
+        ("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
+        # Spanish
+        ("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
+        ("Había 50 soldados.", "Había cincuenta soldados.", "es"),
+        ("Este es un 1er test", "Este es un primero test", "es"),
+        ("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
+        ("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
+        ("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
+        # Italian
+        ("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
+        ("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
+        ("Questo è un 1° test", "Questo è un primo test", "it"),
+        ("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
+        ("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
+        ("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
+        # Portuguese
+        ("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
+        ("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
+        ("Este é um 1º teste", "Este é um primeiro teste", "pt"),
+        ("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
+        ("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
+        (
+            "Isso custará 20,15€ senhor.",
+            "Isso custará vinte euros e quinze cêntimos senhor.",
+            "pt",
+        ),  # "cêntimos" should be "centavos" num2words issue
+        # Polish
+        ("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
+        ("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
+        ("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
+        ("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
+        # Arabic
+        ("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
+        ("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
+        # ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words
+        # ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
+        # Czech
+        ("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
+        ("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
+        ("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
+        ("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
+        # Russian
+        ("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
+        ("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
+        ("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
+        ("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
+        # Dutch
+        ("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
+        ("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
+        ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
+        ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
+        # Chinese (Simplified)
+        ("在12.5秒内", "在十二点五秒内", "zh"),
+        ("有50名士兵", "有五十名士兵", "zh"),
+        # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
+        # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
+        # Turkish
+        # ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
+        ("50 asker vardı.", "elli asker vardı.", "tr"),
+        ("Bu 1. test", "Bu birinci test", "tr"),
+        # ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
+        # Hungarian
+        ("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
+        ("50 katona volt.", "ötven katona volt.", "hu"),
+        ("Ez az 1. teszt", "Ez az első teszt", "hu"),
+        # Korean
+        ("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
+        ("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
+        ("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
+    ]
+    for a, b, lang in test_cases:
+        out = expand_numbers_multilingual(a, lang=lang)
+        assert out == b, f"'{out}' vs '{b}'"
+
+
+def test_abbreviations_multilingual():
+    test_cases = [
+        # English
+        ("Hello Mr. Smith.", "Hello mister Smith.", "en"),
+        ("Dr. Jones is here.", "doctor Jones is here.", "en"),
+        # Spanish
+        ("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
+        ("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
+        # French
+        ("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
+        ("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
+        # German
+        ("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
+        # Portuguese
+        ("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
+        ("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
+        # Italian
+        ("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
+        # ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
+        # Polish
+        ("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
+        ("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
+        # Czech
+        ("P. Novák", "pan Novák", "cs"),
+        ("Dr. Vojtěch", "doktor Vojtěch", "cs"),
+        # Dutch
+        ("Dhr. Jansen", "de heer Jansen", "nl"),
+        ("Mevr. de Vries", "mevrouw de Vries", "nl"),
+        # Russian
+        ("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"),
+        ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
+        # Turkish
+        ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
+        ("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
+        # Hungarian
+        ("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
+    ]
+
+    for a, b, lang in test_cases:
+        out = expand_abbreviations_multilingual(a, lang=lang)
+        assert out == b, f"'{out}' vs '{b}'"
+
+
+def test_symbols_multilingual():
+    test_cases = [
+        ("I have 14% battery", "I have 14 percent battery", "en"),
+        ("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"),
+        ("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"),
+        ("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"),
+        ("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"),
+        ("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"),
+        ("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"),
+        ("Mám 14% baterie", "Mám 14 procento baterie", "cs"),
+        ("Těším se na tebe @ party", "Těším se na tebe na party", "cs"),
+        ("У меня 14% заряда", "У меня 14 процентов заряда", "ru"),
+        ("Я буду @ дома", "Я буду собака дома", "ru"),
+        ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
+        ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
+        ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
+        ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
+        ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
+        ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
+        ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
+    ]
+
+    for a, b, lang in test_cases:
+        out = expand_symbols_multilingual(a, lang=lang)
+        assert out == b, f"'{out}' vs '{b}'"
+
+
+if __name__ == "__main__":
+    test_expand_numbers_multilingual()
+    test_abbreviations_multilingual()
+    test_symbols_multilingual()
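The tests added above exercise the multilingual text normalizers directly. As a minimal sketch of how one of them behaves in isolation (the import path is assumed from the hunk's `VoiceBpeTokenizer` context, since the file name is not shown here):

```python
# Hedged sketch: calling the number expander the tests above rely on.
from TTS.tts.layers.xtts.tokenizer import expand_numbers_multilingual

print(expand_numbers_multilingual("That will be 20.15€ sir.", lang="en"))
# Expected per the test table above: "That will be twenty euro, fifteen cents sir."
```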
@@ -2,13 +2,11 @@ import os
 import random
 import sys
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.data
-import torchaudio
-from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
-from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
+from TTS.tts.models.xtts import load_audio
 
 torch.set_num_threads(1)
 
@@ -50,31 +48,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
     return rel_clip, rel_clip.shape[-1], cond_idxs
 
 
-def load_audio(audiopath, sampling_rate):
-    # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
-    if audiopath[-4:] == ".mp3":
-        # it uses torchaudio with sox backend to load mp3
-        audio, lsr = torchaudio_sox_load(audiopath)
-    else:
-        # it uses torchaudio soundfile backend to load all the others data type
-        audio, lsr = torchaudio_soundfile_load(audiopath)
-
-    # stereo to mono if needed
-    if audio.size(0) != 1:
-        audio = torch.mean(audio, dim=0, keepdim=True)
-
-    if lsr != sampling_rate:
-        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
-
-    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
-    # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
-    if torch.any(audio > 10) or not torch.any(audio < 0):
-        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
-    # clip audio invalid values
-    audio.clip_(-1, 1)
-    return audio
-
-
 class XTTSDataset(torch.utils.data.Dataset):
     def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
         self.config = config
@@ -88,6 +61,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         self.sample_rate = sample_rate
         self.max_wav_len = model_args.max_wav_length
         self.max_text_len = model_args.max_text_length
+        self.use_masking_gt_prompt_approach = model_args.gpt_use_masking_gt_prompt_approach
         assert self.max_wav_len is not None and self.max_text_len is not None
 
         self.samples = samples
@@ -109,7 +83,7 @@ class XTTSDataset(torch.utils.data.Dataset):
             try:
                 tseq, _, wav, _, _, _ = self.load_item(sample)
             except:
-                pass
+                continue
             # Basically, this audio file is nonexistent or too long to be supported by the dataset.
             if (
                 wav is None
@@ -140,10 +114,24 @@ class XTTSDataset(torch.utils.data.Dataset):
            # Ultra short clips are also useless (and can cause problems within some models).
            raise ValueError

+        if self.use_masking_gt_prompt_approach:
            # get a slice from GT to condition the model
-            cond, cond_len, cond_idxs = get_prompt_slice(
+            cond, _, cond_idxs = get_prompt_slice(
                audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
            )
+            # if use masking do not use cond_len
+            cond_len = torch.nan
+        else:
+            ref_sample = (
+                sample["reference_path"]
+                if "reference_path" in sample and sample["reference_path"] is not None
+                else audiopath
+            )
+            cond, cond_len, _ = get_prompt_slice(
+                ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+            )
+            # if do not use masking use cond_len
+            cond_idxs = torch.nan

        return tseq, audiopath, wav, cond, cond_len, cond_idxs

@@ -199,8 +187,10 @@ class XTTSDataset(torch.utils.data.Dataset):
            "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
            "filenames": audiopath,
            "conditioning": cond.unsqueeze(1),
-            "cond_lens": torch.tensor(cond_len, dtype=torch.long),
-            "cond_idxs": torch.tensor(cond_idxs),
+            "cond_lens": torch.tensor(cond_len, dtype=torch.long)
+            if cond_len is not torch.nan
+            else torch.tensor([cond_len]),
+            "cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]),
        }
        return res

@@ -221,6 +211,13 @@ class XTTSDataset(torch.utils.data.Dataset):
        batch["conditioning"] = torch.stack(batch["conditioning"])
        batch["cond_lens"] = torch.stack(batch["cond_lens"])
        batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
+
+        if torch.any(batch["cond_idxs"].isnan()):
+            batch["cond_idxs"] = None
+
+        if torch.any(batch["cond_lens"].isnan()):
+            batch["cond_lens"] = None
+
        max_text_len = batch["text_lengths"].max()
        max_wav_len = batch["wav_lengths"].max()

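The dataset changes above use `torch.nan` as a per-item sentinel for whichever conditioning field does not apply (lengths under the masking approach, indices under the reference-clip approach), and the collate step then turns any NaN-containing batch tensor into `None`. A small self-contained sketch of that sentinel handling (illustrative only, not the dataset class itself):

```python
# Minimal sketch of the NaN-sentinel pattern added in the collate step above.
import torch

cond_lens = torch.stack([torch.tensor([torch.nan]), torch.tensor([torch.nan])])
cond_idxs = torch.stack([torch.tensor([12.0, 5012.0]), torch.tensor([0.0, 4300.0])])

batch = {"cond_lens": cond_lens, "cond_idxs": cond_idxs}
for key in ("cond_lens", "cond_idxs"):
    if torch.any(batch[key].isnan()):
        batch[key] = None

print(batch["cond_lens"])  # None -> masking approach, lengths are unused downstream
print(batch["cond_idxs"])  # kept as a tensor of [start, end] indices
```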
@@ -141,6 +141,19 @@ class GPTTrainer(BaseTTS):
            print(">> GPT weights restored from:", self.args.gpt_checkpoint)

        # Mel spectrogram extractor for conditioning
+        if self.args.gpt_use_perceiver_resampler:
+            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
+                filter_length=2048,
+                hop_length=256,
+                win_length=1024,
+                normalize=False,
+                sampling_rate=config.audio.sample_rate,
+                mel_fmin=0,
+                mel_fmax=8000,
+                n_mel_channels=80,
+                mel_norm_file=self.args.mel_norm_file,
+            )
+        else:
            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
                filter_length=4096,
                hop_length=1024,
@@ -186,7 +199,7 @@ class GPTTrainer(BaseTTS):
    def device(self):
        return next(self.parameters()).device

-    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
+    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens):
        """
        Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
        (actuated by `text_first`).
@@ -197,9 +210,16 @@ class GPTTrainer(BaseTTS):
            wav_lengths: long tensor, (b,)
            cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
            cond_idxs: cond start and end indexs, (b, 2)
+            cond_lens: long tensor, (b,)
        """
        losses = self.xtts.gpt(
-            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
+            text_inputs,
+            text_lengths,
+            audio_codes,
+            wav_lengths,
+            cond_mels=cond_mels,
+            cond_idxs=cond_idxs,
+            cond_lens=cond_lens,
        )
        return losses

@@ -213,7 +233,11 @@ class GPTTrainer(BaseTTS):
            print(" | > Synthesizing test sentences.")
            for idx, s_info in enumerate(self.config.test_sentences):
                wav = self.xtts.synthesize(
-                    s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3
+                    s_info["text"],
+                    self.config,
+                    s_info["speaker_wav"],
+                    s_info["language"],
+                    gpt_cond_len=3,
                )["wav"]
                test_audios["{}-audio".format(idx)] = wav

@@ -269,7 +293,6 @@ class GPTTrainer(BaseTTS):
        del batch["padded_text"]
        del batch["wav"]
        del batch["conditioning"]
-        del batch["cond_lens"]

        return batch

    def train_step(self, batch, criterion):
@@ -280,8 +303,11 @@ class GPTTrainer(BaseTTS):
        audio_codes = batch["audio_codes"]
        wav_lengths = batch["wav_lengths"]
        cond_idxs = batch["cond_idxs"]
+        cond_lens = batch["cond_lens"]

-        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
+        loss_text, loss_mel, _ = self.forward(
+            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens
+        )
        loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
        loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
        loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
@@ -292,9 +318,10 @@ class GPTTrainer(BaseTTS):
        batch["cond_idxs"] = None
        return self.train_step(batch, criterion)

-    def on_epoch_start(self, trainer):  # pylint: disable=W0613
-        # guarante that dvae will be in eval mode after .train() on evaluation end
-        self.dvae = self.dvae.eval()
+    def on_train_epoch_start(self, trainer):
+        trainer.model.eval()  # the whole model to eval
+        # put gpt model in training mode
+        trainer.model.xtts.gpt.train()

    def on_init_end(self, trainer):  # pylint: disable=W0613
        # ignore similarities.pth on clearml save/upload
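The new `on_train_epoch_start` hook above switches the whole model to eval mode and then re-enables training mode only on the GPT submodule, so frozen components (such as the DVAE) keep eval-time behavior during training. A toy illustration of that `eval()`/`train()` pattern, using dummy modules rather than the real trainer:

```python
# Toy sketch of the eval-everything-then-train-one-submodule pattern used above.
import torch.nn as nn

class ToyXtts(nn.Module):
    def __init__(self):
        super().__init__()
        self.gpt = nn.Linear(4, 4)     # stands in for the trainable GPT
        self.dvae = nn.BatchNorm1d(4)  # stands in for a frozen, eval-only component

model = ToyXtts()
model.eval()       # everything to eval first...
model.gpt.train()  # ...then only the GPT part back to training mode

print(model.gpt.training, model.dvae.training)  # True False
```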
@@ -1,385 +0,0 @@
-import json
-from dataclasses import dataclass
-from enum import Enum
-from typing import Callable, Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-MAX_WAV_VALUE = 32768.0
-
-
-class KernelPredictor(torch.nn.Module):
-    """Kernel predictor for the location-variable convolutions"""
-
-    def __init__(
-        self,
-        cond_channels,
-        conv_in_channels,
-        conv_out_channels,
-        conv_layers,
-        conv_kernel_size=3,
-        kpnet_hidden_channels=64,
-        kpnet_conv_size=3,
-        kpnet_dropout=0.0,
-        kpnet_nonlinear_activation="LeakyReLU",
-        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
-    ):
-        """
-        Args:
-            cond_channels (int): number of channel for the conditioning sequence,
-            conv_in_channels (int): number of channel for the input sequence,
-            conv_out_channels (int): number of channel for the output sequence,
-            conv_layers (int): number of layers
-        """
-        super().__init__()
-
-        self.conv_in_channels = conv_in_channels
-        self.conv_out_channels = conv_out_channels
-        self.conv_kernel_size = conv_kernel_size
-        self.conv_layers = conv_layers
-
-        kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers  # l_w
-        kpnet_bias_channels = conv_out_channels * conv_layers  # l_b
-
-        self.input_conv = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
-            getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-        )
-
-        self.residual_convs = nn.ModuleList()
-        padding = (kpnet_conv_size - 1) // 2
-        for _ in range(3):
-            self.residual_convs.append(
-                nn.Sequential(
-                    nn.Dropout(kpnet_dropout),
-                    nn.utils.weight_norm(
-                        nn.Conv1d(
-                            kpnet_hidden_channels,
-                            kpnet_hidden_channels,
-                            kpnet_conv_size,
-                            padding=padding,
-                            bias=True,
-                        )
-                    ),
-                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                    nn.utils.weight_norm(
-                        nn.Conv1d(
-                            kpnet_hidden_channels,
-                            kpnet_hidden_channels,
-                            kpnet_conv_size,
-                            padding=padding,
-                            bias=True,
-                        )
-                    ),
-                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                )
-            )
-        self.kernel_conv = nn.utils.weight_norm(
-            nn.Conv1d(
-                kpnet_hidden_channels,
-                kpnet_kernel_channels,
-                kpnet_conv_size,
-                padding=padding,
-                bias=True,
-            )
-        )
-        self.bias_conv = nn.utils.weight_norm(
-            nn.Conv1d(
-                kpnet_hidden_channels,
-                kpnet_bias_channels,
-                kpnet_conv_size,
-                padding=padding,
-                bias=True,
-            )
-        )
-
-    def forward(self, c):
-        """
-        Args:
-            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
-        """
-        batch, _, cond_length = c.shape
-        c = self.input_conv(c)
-        for residual_conv in self.residual_convs:
-            residual_conv.to(c.device)
-            c = c + residual_conv(c)
-        k = self.kernel_conv(c)
-        b = self.bias_conv(c)
-        kernels = k.contiguous().view(
-            batch,
-            self.conv_layers,
-            self.conv_in_channels,
-            self.conv_out_channels,
-            self.conv_kernel_size,
-            cond_length,
-        )
-        bias = b.contiguous().view(
-            batch,
-            self.conv_layers,
-            self.conv_out_channels,
-            cond_length,
-        )
-
-        return kernels, bias
-
-    def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.input_conv[0])
-        nn.utils.remove_weight_norm(self.kernel_conv)
-        nn.utils.remove_weight_norm(self.bias_conv)
-        for block in self.residual_convs:
-            nn.utils.remove_weight_norm(block[1])
-            nn.utils.remove_weight_norm(block[3])
-
-
-class LVCBlock(torch.nn.Module):
-    """the location-variable convolutions"""
-
-    def __init__(
-        self,
-        in_channels,
-        cond_channels,
-        stride,
-        dilations=[1, 3, 9, 27],
-        lReLU_slope=0.2,
-        conv_kernel_size=3,
-        cond_hop_length=256,
-        kpnet_hidden_channels=64,
-        kpnet_conv_size=3,
-        kpnet_dropout=0.0,
-    ):
-        super().__init__()
-
-        self.cond_hop_length = cond_hop_length
-        self.conv_layers = len(dilations)
-        self.conv_kernel_size = conv_kernel_size
-
-        self.kernel_predictor = KernelPredictor(
-            cond_channels=cond_channels,
-            conv_in_channels=in_channels,
-            conv_out_channels=2 * in_channels,
-            conv_layers=len(dilations),
-            conv_kernel_size=conv_kernel_size,
-            kpnet_hidden_channels=kpnet_hidden_channels,
-            kpnet_conv_size=kpnet_conv_size,
-            kpnet_dropout=kpnet_dropout,
-            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
-        )
-
-        self.convt_pre = nn.Sequential(
-            nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(
-                nn.ConvTranspose1d(
-                    in_channels,
-                    in_channels,
-                    2 * stride,
-                    stride=stride,
-                    padding=stride // 2 + stride % 2,
-                    output_padding=stride % 2,
-                )
-            ),
-        )
-
-        self.conv_blocks = nn.ModuleList()
-        for dilation in dilations:
-            self.conv_blocks.append(
-                nn.Sequential(
-                    nn.LeakyReLU(lReLU_slope),
-                    nn.utils.weight_norm(
-                        nn.Conv1d(
-                            in_channels,
-                            in_channels,
-                            conv_kernel_size,
-                            padding=dilation * (conv_kernel_size - 1) // 2,
-                            dilation=dilation,
-                        )
-                    ),
-                    nn.LeakyReLU(lReLU_slope),
-                )
-            )
-
-    def forward(self, x, c):
-        """forward propagation of the location-variable convolutions.
-        Args:
-            x (Tensor): the input sequence (batch, in_channels, in_length)
-            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
-
-        Returns:
-            Tensor: the output sequence (batch, in_channels, in_length)
-        """
-        _, in_channels, _ = x.shape  # (B, c_g, L')
-
-        x = self.convt_pre(x)  # (B, c_g, stride * L')
-        kernels, bias = self.kernel_predictor(c)
-
-        for i, conv in enumerate(self.conv_blocks):
-            output = conv(x)  # (B, c_g, stride * L')
-
-            k = kernels[:, i, :, :, :, :]  # (B, 2 * c_g, c_g, kernel_size, cond_length)
-            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)
-
-            output = self.location_variable_convolution(
-                output, k, b, hop_size=self.cond_hop_length
-            )  # (B, 2 * c_g, stride * L'): LVC
-            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
-                output[:, in_channels:, :]
-            )  # (B, c_g, stride * L'): GAU
-
-        return x
-
-    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
-        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
-        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
-        Args:
-            x (Tensor): the input sequence (batch, in_channels, in_length).
-            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
-            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
-            dilation (int): the dilation of convolution.
-            hop_size (int): the hop_size of the conditioning sequence.
-        Returns:
-            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
-        """
-        batch, _, in_length = x.shape
-        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
-        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
-
-        padding = dilation * int((kernel_size - 1) / 2)
-        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
-        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)
-
-        if hop_size < dilation:
-            x = F.pad(x, (0, dilation), "constant", 0)
-        x = x.unfold(
-            3, dilation, dilation
-        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
-        x = x[:, :, :, :, :hop_size]
-        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
-        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
-
-        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
-        o = o.to(memory_format=torch.channels_last_3d)
-        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
-        o = o + bias
-        o = o.contiguous().view(batch, out_channels, -1)
-
-        return o
-
-    def remove_weight_norm(self):
-        self.kernel_predictor.remove_weight_norm()
-        nn.utils.remove_weight_norm(self.convt_pre[1])
-        for block in self.conv_blocks:
-            nn.utils.remove_weight_norm(block[1])
-
-
-class UnivNetGenerator(nn.Module):
-    """
-    UnivNet Generator
-
-    Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
-    """
-
-    def __init__(
-        self,
-        noise_dim=64,
-        channel_size=32,
-        dilations=[1, 3, 9, 27],
-        strides=[8, 8, 4],
-        lReLU_slope=0.2,
-        kpnet_conv_size=3,
-        # Below are MEL configurations options that this generator requires.
-        hop_length=256,
-        n_mel_channels=100,
-    ):
-        super(UnivNetGenerator, self).__init__()
-        self.mel_channel = n_mel_channels
-        self.noise_dim = noise_dim
-        self.hop_length = hop_length
-        channel_size = channel_size
-        kpnet_conv_size = kpnet_conv_size
-
-        self.res_stack = nn.ModuleList()
-        hop_length = 1
-        for stride in strides:
-            hop_length = stride * hop_length
-            self.res_stack.append(
-                LVCBlock(
-                    channel_size,
-                    n_mel_channels,
-                    stride=stride,
-                    dilations=dilations,
-                    lReLU_slope=lReLU_slope,
-                    cond_hop_length=hop_length,
-                    kpnet_conv_size=kpnet_conv_size,
-                )
-            )
-
-        self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
-
-        self.conv_post = nn.Sequential(
-            nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
-            nn.Tanh(),
-        )
-
-    def forward(self, c, z):
-        """
-        Args:
-            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
-            z (Tensor): the noise sequence (batch, noise_dim, in_length)
-
-        """
-        z = self.conv_pre(z)  # (B, c_g, L)
-
-        for res_block in self.res_stack:
-            res_block.to(z.device)
-            z = res_block(z, c)  # (B, c_g, L * s_0 * ... * s_i)
-
-        z = self.conv_post(z)  # (B, 1, L * 256)
-
-        return z
-
-    def eval(self, inference=False):
-        super(UnivNetGenerator, self).eval()
-        # don't remove weight norm while validation in training loop
-        if inference:
-            self.remove_weight_norm()
-
-    def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.conv_pre)
-
-        for layer in self.conv_post:
-            if len(layer.state_dict()) != 0:
-                nn.utils.remove_weight_norm(layer)
-
-        for res_block in self.res_stack:
-            res_block.remove_weight_norm()
-
-    def inference(self, c, z=None):
-        # pad input mel with zeros to cut artifact
-        # see https://github.com/seungwonpark/melgan/issues/8
-        zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
-        mel = torch.cat((c, zero), dim=2)
-
-        if z is None:
-            z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
-
-        audio = self.forward(mel, z)
-        audio = audio[:, :, : -(self.hop_length * 10)]
-        audio = audio.clamp(min=-1, max=1)
-        return audio
-
-
-if __name__ == "__main__":
-    model = UnivNetGenerator()
-
-    c = torch.randn(3, 100, 10)
-    z = torch.randn(3, 64, 10)
-    print(c.shape)
-
-    y = model(c, z)
-    print(y.shape)
-    assert y.shape == torch.Size([3, 1, 2560])
-
-    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(pytorch_total_params)
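For context on the shape assertion in the removed file's `__main__` block: the three LVC strides multiply to the generator's hop length, so ten conditioning mel frames yield 2560 output samples. A quick arithmetic check (illustrative only, independent of the deleted module):

```python
# Why the removed vocoder asserts y.shape == (3, 1, 2560) for 10 mel frames.
strides = [8, 8, 4]
hop = 1
for s in strides:
    hop *= s
print(hop)       # 256 samples generated per mel frame
print(10 * hop)  # 2560
```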
File diff suppressed because it is too large
@@ -1,21 +1,16 @@
 import os
-from contextlib import contextmanager
 from dataclasses import dataclass
 
+import librosa
 import torch
 import torch.nn.functional as F
 import torchaudio
-import librosa
 from coqpit import Coqpit
 
-from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
-from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
-from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
-from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
@@ -23,7 +18,19 @@ init_stream_support()
 
 
 def wav_to_mel_cloning(
-    wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
+    wav,
+    mel_norms_file="../experiments/clips_mel_norms.pth",
+    mel_norms=None,
+    device=torch.device("cpu"),
+    n_fft=4096,
+    hop_length=1024,
+    win_length=4096,
+    power=2,
+    normalized=False,
+    sample_rate=22050,
+    f_min=0,
+    f_max=8000,
+    n_mels=80,
 ):
     """
     Convert waveform to mel-spectrogram with hard-coded parameters for cloning.
@@ -38,15 +45,15 @@ def wav_to_mel_cloning(
        torch.Tensor: Mel-spectrogram tensor.
    """
    mel_stft = torchaudio.transforms.MelSpectrogram(
-        n_fft=4096,
-        hop_length=1024,
-        win_length=4096,
-        power=2,
-        normalized=False,
-        sample_rate=22050,
-        f_min=0,
-        f_max=8000,
-        n_mels=80,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        power=power,
+        normalized=normalized,
+        sample_rate=sample_rate,
+        f_min=f_min,
+        f_max=f_max,
+        n_mels=n_mels,
        norm="slaney",
    ).to(device)
    wav = wav.to(device)
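The hunk above turns the previously hard-coded mel settings into keyword arguments; later in this commit the perceiver-resampler path calls the helper with a shorter analysis window (n_fft=2048, hop 256, win 1024) while the legacy path keeps 4096/1024/4096. A rough sketch of the two configurations using torchaudio directly (values taken from the hunks in this commit; not the project helper itself):

```python
# Comparing the two mel configurations this commit switches between.
import torch
import torchaudio

wav = torch.randn(1, 22050)  # one second of fake audio at 22.05 kHz

mel_v1 = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050, n_fft=4096, hop_length=1024, win_length=4096,
    power=2, normalized=False, f_min=0, f_max=8000, n_mels=80, norm="slaney",
)(wav)
mel_v2 = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050, n_fft=2048, hop_length=256, win_length=1024,
    power=2, normalized=False, f_min=0, f_max=8000, n_mels=80, norm="slaney",
)(wav)
print(mel_v1.shape, mel_v2.shape)  # the smaller hop yields roughly 4x more frames
```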
@@ -58,6 +65,28 @@
     return mel
 
 
+def load_audio(audiopath, sampling_rate):
+    # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
+
+    # torchaudio should chose proper backend to load audio depending on platform
+    audio, lsr = torchaudio.load(audiopath)
+
+    # stereo to mono if needed
+    if audio.size(0) != 1:
+        audio = torch.mean(audio, dim=0, keepdim=True)
+
+    if lsr != sampling_rate:
+        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
+
+    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
+    # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
+    if torch.any(audio > 10) or not torch.any(audio < 0):
+        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
+    # clip audio invalid values
+    audio.clip_(-1, 1)
+    return audio
+
+
 def pad_or_truncate(t, length):
     """
     Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
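The new shared `load_audio` above lets torchaudio pick its backend, downmixes to mono, resamples, and clips out-of-range samples. A hedged usage sketch (the wav path is a placeholder, not a file shipped with the repo):

```python
# Using the load_audio helper added above; this is the same function the
# training dataset now imports from TTS.tts.models.xtts.
from TTS.tts.models.xtts import load_audio

audio = load_audio("speaker_ref.wav", sampling_rate=22050)  # placeholder path
print(audio.shape)  # (1, num_samples): mono, resampled to 22050 Hz, clipped to [-1, 1]
```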
@@ -77,78 +106,6 @@ def pad_or_truncate(t, length):
     return tp
 
 
-def load_discrete_vocoder_diffuser(
-    trained_diffusion_steps=4000,
-    desired_diffusion_steps=200,
-    cond_free=True,
-    cond_free_k=1,
-    sampler="ddim",
-):
-    """
-    Load a GaussianDiffusion instance configured for use as a decoder.
-
-    Args:
-        trained_diffusion_steps (int): The number of diffusion steps used during training.
-        desired_diffusion_steps (int): The number of diffusion steps to use during inference.
-        cond_free (bool): Whether to use a conditioning-free model.
-        cond_free_k (int): The number of samples to use for conditioning-free models.
-        sampler (str): The name of the sampler to use.
-
-    Returns:
-        A SpacedDiffusion instance configured with the given parameters.
-    """
-    return SpacedDiffusion(
-        use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
-        model_mean_type="epsilon",
-        model_var_type="learned_range",
-        loss_type="mse",
-        betas=get_named_beta_schedule("linear", trained_diffusion_steps),
-        conditioning_free=cond_free,
-        conditioning_free_k=cond_free_k,
-        sampler=sampler,
-    )
-
-
-def do_spectrogram_diffusion(
-    diffusion_model,
-    diffuser,
-    latents,
-    conditioning_latents,
-    temperature=1,
-):
-    """
-    Generate a mel-spectrogram using a diffusion model and a diffuser.
-
-    Args:
-        diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
-        diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise.
-        latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes.
-        conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes.
-        temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1.
-
-    Returns:
-        torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram.
-    """
-    with torch.no_grad():
-        output_seq_len = (
-            latents.shape[1] * 4 * 24000 // 22050
-        )  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
-        output_shape = (latents.shape[0], 100, output_seq_len)
-        precomputed_embeddings = diffusion_model.timestep_independent(
-            latents, conditioning_latents, output_seq_len, False
-        )
-
-        noise = torch.randn(output_shape, device=latents.device) * temperature
-        mel = diffuser.sample_loop(
-            diffusion_model,
-            output_shape,
-            noise=noise,
-            model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
-            progress=False,
-        )
-        return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
-
-
 @dataclass
 class XttsAudioConfig(Coqpit):
     """
@@ -156,12 +113,10 @@ class XttsAudioConfig(Coqpit):
 
     Args:
         sample_rate (int): The sample rate in which the GPT operates.
-        diffusion_sample_rate (int): The sample rate of the diffusion audio waveform.
         output_sample_rate (int): The sample rate of the output audio waveform.
     """
 
     sample_rate: int = 22050
-    diffusion_sample_rate: int = 24000
     output_sample_rate: int = 24000
 
 
@@ -177,32 +132,21 @@ class XttsArgs(Coqpit):
        clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
        decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
        num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
-        use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.

        For GPT model:
-        ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
-        ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
-        ar_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
-        ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
-        ar_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
-        ar_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
-        ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
-        ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
+        gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
+        gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
+        gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
+        gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
+        gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
+        gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
+        gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
+        gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
        gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
-        ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
-
-        For DiffTTS model:
-        diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
-        diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10.
-        diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100.
-        diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200.
-        diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024.
-        diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193.
-        diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0.
-        diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False.
-        diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16.
-        diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0.
-        diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0.
+        gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
+        gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
+        gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
+        gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.
    """

    gpt_batch_size: int = 1
@@ -212,8 +156,6 @@ class XttsArgs(Coqpit):
    clvp_checkpoint: str = None
    decoder_checkpoint: str = None
    num_chars: int = 255
-    use_hifigan: bool = True
-    use_ne_hifigan: bool = False

    # XTTS GPT Encoder params
    tokenizer_file: str = ""
@@ -229,25 +171,14 @@ class XttsArgs(Coqpit):
    gpt_num_audio_tokens: int = 8194
    gpt_start_audio_token: int = 8192
    gpt_stop_audio_token: int = 8193
+    gpt_code_stride_len: int = 1024
-    # Diffusion Decoder params
-    diff_model_channels: int = 1024
-    diff_num_layers: int = 10
-    diff_in_channels: int = 100
-    diff_out_channels: int = 200
-    diff_in_latent_channels: int = 1024
-    diff_in_tokens: int = 8193
-    diff_dropout: int = 0
-    diff_use_fp16: bool = False
-    diff_num_heads: int = 16
-    diff_layer_drop: int = 0
-    diff_unconditioned_percentage: int = 0
+    gpt_use_masking_gt_prompt_approach: bool = True
+    gpt_use_perceiver_resampler: bool = False

    # HifiGAN Decoder params
    input_sample_rate: int = 22050
    output_sample_rate: int = 24000
    output_hop_length: int = 256
-    ar_mel_length_compression: int = 1024
    decoder_input_dim: int = 1024
    d_vector_dim: int = 512
    cond_d_vector_in_each_upsampling_layer: bool = True
@ -304,119 +235,143 @@ class Xtts(BaseTTS):
|
||||||
num_audio_tokens=self.args.gpt_num_audio_tokens,
|
num_audio_tokens=self.args.gpt_num_audio_tokens,
|
||||||
start_audio_token=self.args.gpt_start_audio_token,
|
start_audio_token=self.args.gpt_start_audio_token,
|
||||||
stop_audio_token=self.args.gpt_stop_audio_token,
|
stop_audio_token=self.args.gpt_stop_audio_token,
|
||||||
|
use_perceiver_resampler=self.args.gpt_use_perceiver_resampler,
|
||||||
|
code_stride_len=self.args.gpt_code_stride_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.use_hifigan:
|
|
||||||
self.hifigan_decoder = HifiDecoder(
|
self.hifigan_decoder = HifiDecoder(
|
||||||
input_sample_rate=self.args.input_sample_rate,
|
input_sample_rate=self.args.input_sample_rate,
|
||||||
output_sample_rate=self.args.output_sample_rate,
|
output_sample_rate=self.args.output_sample_rate,
|
||||||
output_hop_length=self.args.output_hop_length,
|
output_hop_length=self.args.output_hop_length,
|
||||||
ar_mel_length_compression=self.args.ar_mel_length_compression,
|
ar_mel_length_compression=self.args.gpt_code_stride_len,
|
||||||
decoder_input_dim=self.args.decoder_input_dim,
|
decoder_input_dim=self.args.decoder_input_dim,
|
||||||
d_vector_dim=self.args.d_vector_dim,
|
d_vector_dim=self.args.d_vector_dim,
|
||||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.use_ne_hifigan:
|
|
||||||
self.ne_hifigan_decoder = HifiDecoder(
|
|
||||||
input_sample_rate=self.args.input_sample_rate,
|
|
||||||
output_sample_rate=self.args.output_sample_rate,
|
|
||||||
output_hop_length=self.args.output_hop_length,
|
|
||||||
ar_mel_length_compression=self.args.ar_mel_length_compression,
|
|
||||||
decoder_input_dim=self.args.decoder_input_dim,
|
|
||||||
d_vector_dim=self.args.d_vector_dim,
|
|
||||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not (self.args.use_hifigan or self.args.use_ne_hifigan):
|
|
||||||
self.diffusion_decoder = DiffusionTts(
|
|
||||||
model_channels=self.args.diff_model_channels,
|
|
||||||
num_layers=self.args.diff_num_layers,
|
|
||||||
in_channels=self.args.diff_in_channels,
|
|
||||||
out_channels=self.args.diff_out_channels,
|
|
||||||
in_latent_channels=self.args.diff_in_latent_channels,
|
|
||||||
in_tokens=self.args.diff_in_tokens,
|
|
||||||
dropout=self.args.diff_dropout,
|
|
||||||
use_fp16=self.args.diff_use_fp16,
|
|
||||||
num_heads=self.args.diff_num_heads,
|
|
||||||
layer_drop=self.args.diff_layer_drop,
|
|
||||||
unconditioned_percentage=self.args.diff_unconditioned_percentage,
|
|
||||||
)
|
|
||||||
self.vocoder = UnivNetGenerator()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
return next(self.parameters()).device
|
return next(self.parameters()).device
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_gpt_cond_latents(self, audio, sr, length: int = 3):
|
def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
|
||||||
"""Compute the conditioning latents for the GPT model from the given audio.
|
"""Compute the conditioning latents for the GPT model from the given audio.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_path (str): Path to the audio file.
|
audio (tensor): audio tensor.
|
||||||
length (int): Length of the audio in seconds. Defaults to 3.
|
sr (int): Sample rate of the audio.
|
||||||
|
length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
|
||||||
|
chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
|
||||||
|
is being used without chunking. It must be < `length`. Defaults to 6.
|
||||||
"""
|
"""
|
||||||
|
if sr != 22050:
|
||||||
|
audio = torchaudio.functional.resample(audio, sr, 22050)
|
||||||
|
if length > 0:
|
||||||
|
audio = audio[:, : 22050 * length]
|
||||||
|
if self.args.gpt_use_perceiver_resampler:
|
||||||
|
style_embs = []
|
||||||
|
for i in range(0, audio.shape[1], 22050 * chunk_length):
|
||||||
|
audio_chunk = audio[:, i : i + 22050 * chunk_length]
|
||||||
|
mel_chunk = wav_to_mel_cloning(
|
||||||
|
audio_chunk,
|
||||||
|
mel_norms=self.mel_stats.cpu(),
|
||||||
|
n_fft=2048,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
power=2,
|
||||||
|
normalized=False,
|
||||||
|
sample_rate=22050,
|
||||||
|
f_min=0,
|
||||||
|
f_max=8000,
|
||||||
|
n_mels=80,
|
||||||
|
)
|
||||||
|
style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
|
||||||
|
style_embs.append(style_emb)
|
||||||
|
|
||||||
audio_22k = torchaudio.functional.resample(audio, sr, 22050)
|
# mean style embedding
|
||||||
audio_22k = audio_22k[:, : 22050 * length]
|
cond_latent = torch.stack(style_embs).mean(dim=0)
|
||||||
mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
|
else:
|
||||||
|
mel = wav_to_mel_cloning(
|
||||||
|
audio,
|
||||||
|
mel_norms=self.mel_stats.cpu(),
|
||||||
|
n_fft=4096,
|
||||||
|
hop_length=1024,
|
||||||
|
win_length=4096,
|
||||||
|
power=2,
|
||||||
|
normalized=False,
|
||||||
|
sample_rate=22050,
|
||||||
|
f_min=0,
|
||||||
|
f_max=8000,
|
||||||
|
n_mels=80,
|
||||||
|
)
|
||||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
|
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
|
||||||
return cond_latent.transpose(1, 2)
|
return cond_latent.transpose(1, 2)
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def get_diffusion_cond_latents(self, audio, sr):
|
|
||||||
from math import ceil
|
|
||||||
|
|
||||||
diffusion_conds = []
|
|
||||||
CHUNK_SIZE = 102400
|
|
||||||
audio_24k = torchaudio.functional.resample(audio, sr, 24000)
|
|
||||||
for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
|
|
||||||
current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
|
|
||||||
current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
|
|
||||||
cond_mel = wav_to_univnet_mel(
|
|
||||||
current_sample.to(self.device),
|
|
||||||
do_normalization=False,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
diffusion_conds.append(cond_mel)
|
|
||||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
|
||||||
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
|
|
||||||
return diffusion_latent
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_speaker_embedding(self, audio, sr):
|
def get_speaker_embedding(self, audio, sr):
|
||||||
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
|
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
|
||||||
return self.hifigan_decoder.speaker_encoder.forward(
|
return (
|
||||||
audio_16k.to(self.device), l2_norm=True
|
self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
|
||||||
).unsqueeze(-1).to(self.device)
|
.unsqueeze(-1)
|
||||||
|
.to(self.device)
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
@torch.inference_mode()
def get_conditioning_latents(
def get_conditioning_latents(
self,
self,
audio_path,
audio_path,
max_ref_length=30,
gpt_cond_len=6,
gpt_cond_len=6,
max_ref_length=10,
gpt_cond_chunk_len=6,
librosa_trim_db=None,
librosa_trim_db=None,
sound_norm_refs=False,
sound_norm_refs=False,
load_sr=22050,
):
):
speaker_embedding = None
"""Get the conditioning latents for the GPT model from the given audio.
diffusion_cond_latents = None
audio, sr = torchaudio.load(audio_path)
Args:
audio = audio[:, : sr * max_ref_length].to(self.device)
audio_path (str or List[str]): Path to reference audio file(s).
if audio.shape[0] > 1:
max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
audio = audio.mean(0, keepdim=True)
gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_cond_len. Defaults to 6.
librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None.
sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
load_sr (int, optional): Sample rate to load the audio. Defaults to 22050.
"""
# deal with multiple references
if not isinstance(audio_path, list):
audio_paths = [audio_path]
else:
audio_paths = audio_path
speaker_embeddings = []
audios = []
speaker_embedding = None
for file_path in audio_paths:
audio = load_audio(file_path, load_sr)
audio = audio[:, : load_sr * max_ref_length].to(self.device)
if sound_norm_refs:
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
if self.args.use_hifigan or self.args.use_ne_hifigan:
# compute latents for the decoder
speaker_embedding = self.get_speaker_embedding(audio, sr)
speaker_embedding = self.get_speaker_embedding(audio, load_sr)
else:
speaker_embeddings.append(speaker_embedding)
diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
audios.append(audio)
return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
# merge all the audios and compute the latents for the gpt
full_audio = torch.cat(audios, dim=-1)
gpt_cond_latents = self.get_gpt_cond_latents(
full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
)  # [1, 1024, T]
if speaker_embeddings:
speaker_embedding = torch.stack(speaker_embeddings)
speaker_embedding = speaker_embedding.mean(dim=0)
return gpt_cond_latents, speaker_embedding
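A hedged usage sketch of the multi-reference cloning path added above; the paths are placeholders and the loading calls assume the usual XTTS flow (XttsConfig.load_json, Xtts.init_from_config, load_checkpoint):

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")  # placeholder path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)  # placeholder path

# GPT latents are computed over the concatenated references; speaker embeddings are averaged
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["ref_01.wav", "ref_02.wav"],  # placeholder reference clips
    gpt_cond_len=30,
    gpt_cond_chunk_len=6,
    max_ref_length=30,
)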
def synthesize(self, text, config, speaker_wav, language, **kwargs):
def synthesize(self, text, config, speaker_wav, language, **kwargs):
"""Synthesize speech with the given input text.
"""Synthesize speech with the given input text.
@ -424,7 +379,7 @@ class Xtts(BaseTTS):
Args:
Args:
text (str): Input text.
text (str): Input text.
config (XttsConfig): Config with inference parameters.
config (XttsConfig): Config with inference parameters.
speaker_wav (str): Path to the speaker audio file for cloning.
speaker_wav (list): List of paths to the speaker audio files to be used for cloning.
language (str): Language ID of the speaker.
language (str): Language ID of the speaker.
**kwargs: Inference settings. See `inference()`.
**kwargs: Inference settings. See `inference()`.
@ -434,11 +389,6 @@ class Xtts(BaseTTS):
as latents used at inference.
as latents used at inference.
"""
"""
# Make the synthesizer happy 🥳
if isinstance(speaker_wav, list):
speaker_wav = speaker_wav[0]
return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
@ -446,7 +396,7 @@ class Xtts(BaseTTS):
inference with config
inference with config
"""
"""
assert (
assert (
language in self.config.languages
"zh-cn" if language == "zh" else language in self.config.languages
), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
# Use generally found best tuning knobs for generation.
# Use generally found best tuning knobs for generation.
settings = {
settings = {
@ -455,10 +405,10 @@ class Xtts(BaseTTS):
"repetition_penalty": config.repetition_penalty,
"repetition_penalty": config.repetition_penalty,
"top_k": config.top_k,
"top_k": config.top_k,
"top_p": config.top_p,
"top_p": config.top_p,
"cond_free_k": config.cond_free_k,
"gpt_cond_len": config.gpt_cond_len,
"diffusion_temperature": config.diffusion_temperature,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"decoder_iterations": config.decoder_iterations,
"max_ref_len": config.max_ref_len,
"decoder_sampler": config.decoder_sampler,
"sound_norm_refs": config.sound_norm_refs,
}
}
settings.update(kwargs)  # allow overriding of preset settings with kwargs
settings.update(kwargs)  # allow overriding of preset settings with kwargs
return self.full_inference(text, ref_audio_path, language, **settings)
return self.full_inference(text, ref_audio_path, language, **settings)
@ -470,20 +420,17 @@ class Xtts(BaseTTS):
ref_audio_path,
ref_audio_path,
language,
language,
# GPT inference
# GPT inference
temperature=0.65,
temperature=0.75,
length_penalty=1,
length_penalty=1.0,
repetition_penalty=2.0,
repetition_penalty=10.0,
top_k=50,
top_k=50,
top_p=0.85,
top_p=0.85,
gpt_cond_len=6,
do_sample=True,
do_sample=True,
# Decoder inference
# Cloning
decoder_iterations=100,
gpt_cond_len=30,
cond_free=True,
gpt_cond_chunk_len=6,
cond_free_k=2,
max_ref_len=10,
diffusion_temperature=1.0,
sound_norm_refs=False,
decoder_sampler="ddim",
decoder="hifigan",
**hf_generate_kwargs,
**hf_generate_kwargs,
):
):
"""
"""
@ -512,28 +459,10 @@ class Xtts(BaseTTS):
(aka boring) outputs. Defaults to 0.8.
(aka boring) outputs. Defaults to 0.8.
gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
else the first `gpt_cond_len` secs is used. Defaults to 30 seconds.
decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
more chances to iteratively refine the output, which should theoretically mean a higher quality output.
If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds.
Generally a value above 250 is not noticeably better, however. Defaults to 100.
cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion
performs two forward passes for each diffusion step: one with the outputs of the autoregressive model
and one with no conditioning priors. The output of the two is blended according to the cond_free_k
value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
Defaults to True.
cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the
conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the
conditioning-free signal. Defaults to 2.0.
diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1].
Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to 1.0.
decoder: (str) Selects the decoder to use between ("hifigan", "ne_hifigan" and "diffusion")
Defaults to hifigan
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
@ -543,27 +472,25 @@ class Xtts(BaseTTS):
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Sample rate is 24kHz.
Sample rate is 24kHz.
"""
"""
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
(gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
audio_path=ref_audio_path,
gpt_cond_len=gpt_cond_len,
gpt_cond_chunk_len=gpt_cond_chunk_len,
max_ref_length=max_ref_len,
sound_norm_refs=sound_norm_refs,
)
)
return self.inference(
return self.inference(
text,
text,
language,
language,
gpt_cond_latent,
gpt_cond_latent,
speaker_embedding,
speaker_embedding,
diffusion_conditioning,
temperature=temperature,
temperature=temperature,
length_penalty=length_penalty,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
repetition_penalty=repetition_penalty,
top_k=top_k,
top_k=top_k,
top_p=top_p,
top_p=top_p,
do_sample=do_sample,
do_sample=do_sample,
decoder_iterations=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
diffusion_temperature=diffusion_temperature,
decoder_sampler=decoder_sampler,
decoder=decoder,
**hf_generate_kwargs,
**hf_generate_kwargs,
)
)
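A hedged sketch of calling full_inference with the new cloning knobs shown in the signature above (model is assumed to be a loaded Xtts instance; ref.wav is a placeholder):

out = model.full_inference(
    "Hello world.",
    "ref.wav",             # placeholder reference clip
    "en",
    gpt_cond_len=30,       # seconds of reference audio used for the GPT latents
    gpt_cond_chunk_len=6,  # chunk size for the perceiver-resampler path
    max_ref_len=10,
    sound_norm_refs=False,
)
wav = out["wav"]  # 24 kHz numpy waveform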
@ -574,38 +501,35 @@ class Xtts(BaseTTS):
language,
language,
gpt_cond_latent,
gpt_cond_latent,
speaker_embedding,
speaker_embedding,
diffusion_conditioning,
# GPT inference
# GPT inference
temperature=0.65,
temperature=0.75,
length_penalty=1,
length_penalty=1.0,
repetition_penalty=2.0,
repetition_penalty=10.0,
top_k=50,
top_k=50,
top_p=0.85,
top_p=0.85,
do_sample=True,
do_sample=True,
# Decoder inference
num_beams=1,
decoder_iterations=100,
speed=1.0,
cond_free=True,
enable_text_splitting=False,
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
decoder="hifigan",
**hf_generate_kwargs,
**hf_generate_kwargs,
):
):
text = text.strip().lower()
language = language.split("-")[0]  # remove the country code
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
length_scale = 1.0 / max(speed, 0.05)
if enable_text_splitting:
text = split_sentence(text, language, self.tokenizer.char_limits[language])
else:
text = [text]
wavs = []
gpt_latents_list = []
for sent in text:
sent = sent.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
assert (
assert (
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
if not self.args.use_hifigan:
diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
sampler=decoder_sampler,
)
with torch.no_grad():
with torch.no_grad():
gpt_codes = self.gpt.generate(
gpt_codes = self.gpt.generate(
cond_latents=gpt_cond_latent,
cond_latents=gpt_cond_latent,
@ -616,6 +540,7 @@ class Xtts(BaseTTS):
top_k=top_k,
top_k=top_k,
temperature=temperature,
temperature=temperature,
num_return_sequences=self.gpt_batch_size,
num_return_sequences=self.gpt_batch_size,
num_beams=num_beams,
length_penalty=length_penalty,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
output_attentions=False,
@ -635,35 +560,20 @@ class Xtts(BaseTTS):
return_attentions=False,
return_attentions=False,
return_latent=True,
return_latent=True,
)
)
silence_token = 83
ctokens = 0
for k in range(gpt_codes.shape[-1]):
if gpt_codes[0, k] == silence_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8:
gpt_latents = gpt_latents[:, :k]
break
if decoder == "hifigan":
if length_scale != 1.0:
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
gpt_latents = F.interpolate(
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
elif decoder == "ne_hifigan":
).transpose(1, 2)
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
else:
assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
mel = do_spectrogram_diffusion(
self.diffusion_decoder,
diffuser,
gpt_latents,
diffusion_conditioning,
temperature=diffusion_temperature,
)
wav = self.vocoder.inference(mel)
return {"wav": wav.cpu().numpy().squeeze()}
gpt_latents_list.append(gpt_latents.cpu())
wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
return {
"wav": torch.cat(wavs, dim=0).numpy(),
"gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
"speaker_embedding": speaker_embedding,
}
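The speed knob above time-stretches the GPT latents with linear interpolation before the HiFi-GAN decoder runs. A standalone sketch of the same operation on a dummy latent tensor (shapes are illustrative only):

import torch
import torch.nn.functional as F

gpt_latents = torch.randn(1, 120, 1024)  # [batch, T, channels], dummy values
speed = 1.25                              # >1.0 -> faster speech, shorter latent sequence
length_scale = 1.0 / max(speed, 0.05)
if length_scale != 1.0:
    # interpolate along the time axis; channels must sit on dim 1 for F.interpolate
    gpt_latents = F.interpolate(
        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
    ).transpose(1, 2)
print(gpt_latents.shape)  # roughly [1, 96, 1024] for speed=1.25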
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
"""Handle chunk formatting in streaming mode"""
"""Handle chunk formatting in streaming mode"""
@ -671,10 +581,21 @@ class Xtts(BaseTTS):
if wav_gen_prev is not None:
if wav_gen_prev is not None:
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
if wav_overlap is not None:
if wav_overlap is not None:
# cross fade the overlap section
if overlap_len > len(wav_chunk):
# wav_chunk is smaller than overlap_len, pass on last wav_gen
if wav_gen_prev is not None:
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
else:
# not expected to be hit, since the problem only appears on the last chunk
wav_chunk = wav_gen[-overlap_len:]
return wav_chunk, wav_gen, None
else:
crossfade_wav = wav_chunk[:overlap_len]
crossfade_wav = wav_chunk[:overlap_len]
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
wav_chunk[:overlap_len] += crossfade_wav
wav_chunk[:overlap_len] += crossfade_wav
wav_overlap = wav_gen[-overlap_len:]
wav_overlap = wav_gen[-overlap_len:]
wav_gen_prev = wav_gen
wav_gen_prev = wav_gen
return wav_chunk, wav_gen_prev, wav_overlap
return wav_chunk, wav_gen_prev, wav_overlap
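The streaming chunk handler above blends the tail of the previous chunk into the head of the new one with complementary linear ramps. A standalone sketch of that equal-gain crossfade on dummy waveforms:

import torch

overlap_len = 1024
prev_tail = torch.randn(overlap_len)   # wav_overlap: last samples of the previous chunk
new_chunk = torch.randn(4096)          # freshly decoded chunk (dummy audio)

fade_in = torch.linspace(0.0, 1.0, overlap_len)
fade_out = torch.linspace(1.0, 0.0, overlap_len)
# linear crossfade over the overlapping region avoids clicks at chunk boundaries
new_chunk[:overlap_len] = prev_tail * fade_out + new_chunk[:overlap_len] * fade_in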
@ -690,21 +611,30 @@ class Xtts(BaseTTS):
|
||||||
stream_chunk_size=20,
|
stream_chunk_size=20,
|
||||||
overlap_wav_len=1024,
|
overlap_wav_len=1024,
|
||||||
# GPT inference
|
# GPT inference
|
||||||
temperature=0.65,
|
temperature=0.75,
|
||||||
length_penalty=1,
|
length_penalty=1.0,
|
||||||
repetition_penalty=2.0,
|
repetition_penalty=10.0,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
top_p=0.85,
|
top_p=0.85,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
# Decoder inference
|
speed=1.0,
|
||||||
decoder="hifigan",
|
enable_text_splitting=False,
|
||||||
**hf_generate_kwargs,
|
**hf_generate_kwargs,
|
||||||
):
|
):
|
||||||
assert hasattr(
|
language = language.split("-")[0] # remove the country code
|
||||||
self, "hifigan_decoder"
|
length_scale = 1.0 / max(speed, 0.05)
|
||||||
), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
|
if enable_text_splitting:
|
||||||
text = text.strip().lower()
|
text = split_sentence(text, language, self.tokenizer.char_limits[language])
|
||||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
else:
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
for sent in text:
|
||||||
|
sent = sent.strip().lower()
|
||||||
|
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||||
|
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||||
|
|
||||||
fake_inputs = self.gpt.compute_embeddings(
|
fake_inputs = self.gpt.compute_embeddings(
|
||||||
gpt_cond_latent.to(self.device),
|
gpt_cond_latent.to(self.device),
|
||||||
|
@ -741,14 +671,11 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
|
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
|
||||||
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
||||||
if decoder == "hifigan":
|
if length_scale != 1.0:
|
||||||
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
gpt_latents = F.interpolate(
|
||||||
|
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
|
||||||
|
).transpose(1, 2)
|
||||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||||
elif decoder == "ne_hifigan":
|
|
||||||
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
|
||||||
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Diffusion for streaming inference not implemented.")
|
|
||||||
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
||||||
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|
||||||
)
|
)
|
||||||
|
@ -756,10 +683,14 @@ class Xtts(BaseTTS):
yield wav_chunk
yield wav_chunk
def forward(self):
def forward(self):
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
raise NotImplementedError(
"XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
)
def eval_step(self):
def eval_step(self):
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
raise NotImplementedError(
"XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
)
@staticmethod
@staticmethod
def init_from_config(config: "XttsConfig", **kwargs):  # pylint: disable=unused-argument
def init_from_config(config: "XttsConfig", **kwargs):  # pylint: disable=unused-argument
@ -772,11 +703,8 @@ class Xtts(BaseTTS):
def get_compatible_checkpoint_state_dict(self, model_path):
def get_compatible_checkpoint_state_dict(self, model_path):
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
# remove xtts gpt trainer extra keys
# remove xtts gpt trainer extra keys
ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
for key in list(checkpoint.keys()):
for key in list(checkpoint.keys()):
# check if it is from the coqui Trainer if so convert it
# check if it is from the coqui Trainer if so convert it
if key.startswith("xtts."):
if key.startswith("xtts."):
@ -835,12 +763,11 @@ class Xtts(BaseTTS):
self.load_state_dict(checkpoint, strict=strict)
self.load_state_dict(checkpoint, strict=strict)
if eval:
if eval:
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
self.hifigan_decoder.eval()
if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
if hasattr(self, "vocoder"): self.vocoder.eval()
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
self.gpt.eval()
self.gpt.eval()
def train_step(self):
def train_step(self):
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
raise NotImplementedError(
"XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
)
@ -201,7 +201,6 @@ def stft(
def istft(
def istft(
*,
*,
y: np.ndarray = None,
y: np.ndarray = None,
fft_size: int = None,
hop_length: int = None,
hop_length: int = None,
win_length: int = None,
win_length: int = None,
window: str = "hann",
window: str = "hann",
@ -5,10 +5,26 @@ import librosa
import numpy as np
import numpy as np
import scipy.io.wavfile
import scipy.io.wavfile
import scipy.signal
import scipy.signal
import soundfile as sf
from TTS.tts.utils.helpers import StandardScaler
from TTS.tts.utils.helpers import StandardScaler
from TTS.utils.audio.numpy_transforms import compute_f0
from TTS.utils.audio.numpy_transforms import (
amp_to_db,
build_mel_basis,
compute_f0,
db_to_amp,
deemphasis,
find_endpoint,
griffin_lim,
load_wav,
mel_to_spec,
millisec_to_length,
preemphasis,
rms_volume_norm,
spec_to_mel,
stft,
trim_silence,
volume_norm,
)
# pylint: disable=too-many-public-methods
# pylint: disable=too-many-public-methods
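The AudioProcessor now delegates dB conversion to the functional amp_to_db/db_to_amp helpers imported above. A small illustrative round trip that mirrors the removed _amp_to_db/_db_to_amp methods (the gain/base defaults here are only illustrative; the real values come from the audio config):

import numpy as np

def amp_to_db_demo(x, gain=20.0, base=10):
    # mirrors the removed AudioProcessor._amp_to_db: gain * log_base(max(1e-5, x))
    log = np.log10 if base == 10 else np.log
    return gain * log(np.maximum(1e-5, x))

def db_to_amp_demo(x, gain=20.0, base=10):
    # inverse transform, mirroring the removed AudioProcessor._db_to_amp
    return np.power(10, x / gain) if base == 10 else np.exp(x / gain)

spec = np.abs(np.random.randn(80, 100))
roundtrip = db_to_amp_demo(amp_to_db_demo(spec))
assert np.allclose(roundtrip, np.maximum(1e-5, spec))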
@ -200,7 +216,9 @@ class AudioProcessor(object):
|
||||||
# setup stft parameters
|
# setup stft parameters
|
||||||
if hop_length is None:
|
if hop_length is None:
|
||||||
# compute stft parameters from given time values
|
# compute stft parameters from given time values
|
||||||
self.hop_length, self.win_length = self._stft_parameters()
|
self.win_length, self.hop_length = millisec_to_length(
|
||||||
|
frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# use stft parameters from config file
|
# use stft parameters from config file
|
||||||
self.hop_length = hop_length
|
self.hop_length = hop_length
|
||||||
|
@ -215,8 +233,13 @@ class AudioProcessor(object):
|
||||||
for key, value in members.items():
|
for key, value in members.items():
|
||||||
print(" | > {}:{}".format(key, value))
|
print(" | > {}:{}".format(key, value))
|
||||||
# create spectrogram utils
|
# create spectrogram utils
|
||||||
self.mel_basis = self._build_mel_basis()
|
self.mel_basis = build_mel_basis(
|
||||||
self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
|
sample_rate=self.sample_rate,
|
||||||
|
fft_size=self.fft_size,
|
||||||
|
num_mels=self.num_mels,
|
||||||
|
mel_fmax=self.mel_fmax,
|
||||||
|
mel_fmin=self.mel_fmin,
|
||||||
|
)
|
||||||
# setup scaler
|
# setup scaler
|
||||||
if stats_path and signal_norm:
|
if stats_path and signal_norm:
|
||||||
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
|
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
|
||||||
|
@ -232,35 +255,6 @@ class AudioProcessor(object):
|
||||||
return AudioProcessor(verbose=verbose, **config.audio)
|
return AudioProcessor(verbose=verbose, **config.audio)
|
||||||
return AudioProcessor(verbose=verbose, **config)
|
return AudioProcessor(verbose=verbose, **config)
|
||||||
|
|
||||||
### setting up the parameters ###
|
|
||||||
def _build_mel_basis(
|
|
||||||
self,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Build melspectrogram basis.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: melspectrogram basis.
|
|
||||||
"""
|
|
||||||
if self.mel_fmax is not None:
|
|
||||||
assert self.mel_fmax <= self.sample_rate // 2
|
|
||||||
return librosa.filters.mel(
|
|
||||||
sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
|
|
||||||
)
|
|
||||||
|
|
||||||
def _stft_parameters(
|
|
||||||
self,
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
"""Compute the real STFT parameters from the time values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[int, int]: hop length and window length for STFT.
|
|
||||||
"""
|
|
||||||
factor = self.frame_length_ms / self.frame_shift_ms
|
|
||||||
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
|
|
||||||
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
|
|
||||||
win_length = int(hop_length * factor)
|
|
||||||
return hop_length, win_length
|
|
||||||
|
|
||||||
### normalization ###
|
### normalization ###
|
||||||
def normalize(self, S: np.ndarray) -> np.ndarray:
|
def normalize(self, S: np.ndarray) -> np.ndarray:
|
||||||
"""Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
|
"""Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
|
||||||
|
@ -386,31 +380,6 @@ class AudioProcessor(object):
|
||||||
self.linear_scaler = StandardScaler()
|
self.linear_scaler = StandardScaler()
|
||||||
self.linear_scaler.set_stats(linear_mean, linear_std)
|
self.linear_scaler.set_stats(linear_mean, linear_std)
|
||||||
|
|
||||||
### DB and AMP conversion ###
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
def _amp_to_db(self, x: np.ndarray) -> np.ndarray:
|
|
||||||
"""Convert amplitude values to decibels.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (np.ndarray): Amplitude spectrogram.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Decibels spectrogram.
|
|
||||||
"""
|
|
||||||
return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
|
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
def _db_to_amp(self, x: np.ndarray) -> np.ndarray:
|
|
||||||
"""Convert decibels spectrogram to amplitude spectrogram.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (np.ndarray): Decibels spectrogram.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Amplitude spectrogram.
|
|
||||||
"""
|
|
||||||
return _exp(x / self.spec_gain, self.base)
|
|
||||||
|
|
||||||
### Preemphasis ###
|
### Preemphasis ###
|
||||||
def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
|
def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
|
||||||
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
|
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
|
||||||
|
@ -424,32 +393,13 @@ class AudioProcessor(object):
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: Decorrelated audio signal.
|
np.ndarray: Decorrelated audio signal.
|
||||||
"""
|
"""
|
||||||
if self.preemphasis == 0:
|
return preemphasis(x=x, coef=self.preemphasis)
|
||||||
raise RuntimeError(" [!] Preemphasis is set 0.0.")
|
|
||||||
return scipy.signal.lfilter([1, -self.preemphasis], [1], x)
|
|
||||||
|
|
||||||
def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
|
def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
|
||||||
"""Reverse pre-emphasis."""
|
"""Reverse pre-emphasis."""
|
||||||
if self.preemphasis == 0:
|
return deemphasis(x=x, coef=self.preemphasis)
|
||||||
raise RuntimeError(" [!] Preemphasis is set 0.0.")
|
|
||||||
return scipy.signal.lfilter([1], [1, -self.preemphasis], x)
|
|
||||||
|
|
||||||
### SPECTROGRAMs ###
|
### SPECTROGRAMs ###
|
||||||
def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:
|
|
||||||
"""Project a full scale spectrogram to a melspectrogram.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
spectrogram (np.ndarray): Full scale spectrogram.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Melspectrogram
|
|
||||||
"""
|
|
||||||
return np.dot(self.mel_basis, spectrogram)
|
|
||||||
|
|
||||||
def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
|
|
||||||
"""Convert a melspectrogram to full scale spectrogram."""
|
|
||||||
return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))
|
|
||||||
|
|
||||||
def spectrogram(self, y: np.ndarray) -> np.ndarray:
|
def spectrogram(self, y: np.ndarray) -> np.ndarray:
|
||||||
"""Compute a spectrogram from a waveform.
|
"""Compute a spectrogram from a waveform.
|
||||||
|
|
||||||
|
@ -460,11 +410,16 @@ class AudioProcessor(object):
|
||||||
np.ndarray: Spectrogram.
|
np.ndarray: Spectrogram.
|
||||||
"""
|
"""
|
||||||
if self.preemphasis != 0:
|
if self.preemphasis != 0:
|
||||||
D = self._stft(self.apply_preemphasis(y))
|
y = self.apply_preemphasis(y)
|
||||||
else:
|
D = stft(
|
||||||
D = self._stft(y)
|
y=y,
|
||||||
|
fft_size=self.fft_size,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length,
|
||||||
|
pad_mode=self.stft_pad_mode,
|
||||||
|
)
|
||||||
if self.do_amp_to_db_linear:
|
if self.do_amp_to_db_linear:
|
||||||
S = self._amp_to_db(np.abs(D))
|
S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base)
|
||||||
else:
|
else:
|
||||||
S = np.abs(D)
|
S = np.abs(D)
|
||||||
return self.normalize(S).astype(np.float32)
|
return self.normalize(S).astype(np.float32)
|
||||||
|
@ -472,32 +427,35 @@ class AudioProcessor(object):
|
||||||
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
|
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
|
||||||
"""Compute a melspectrogram from a waveform."""
|
"""Compute a melspectrogram from a waveform."""
|
||||||
if self.preemphasis != 0:
|
if self.preemphasis != 0:
|
||||||
D = self._stft(self.apply_preemphasis(y))
|
y = self.apply_preemphasis(y)
|
||||||
else:
|
D = stft(
|
||||||
D = self._stft(y)
|
y=y,
|
||||||
|
fft_size=self.fft_size,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length,
|
||||||
|
pad_mode=self.stft_pad_mode,
|
||||||
|
)
|
||||||
|
S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis)
|
||||||
if self.do_amp_to_db_mel:
|
if self.do_amp_to_db_mel:
|
||||||
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
|
S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
|
||||||
else:
|
|
||||||
S = self._linear_to_mel(np.abs(D))
|
|
||||||
return self.normalize(S).astype(np.float32)
|
return self.normalize(S).astype(np.float32)
|
||||||
|
|
||||||
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
|
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
|
||||||
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
|
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
|
||||||
S = self.denormalize(spectrogram)
|
S = self.denormalize(spectrogram)
|
||||||
S = self._db_to_amp(S)
|
S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
|
||||||
# Reconstruct phase
|
# Reconstruct phase
|
||||||
if self.preemphasis != 0:
|
W = self._griffin_lim(S**self.power)
|
||||||
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
|
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
|
||||||
return self._griffin_lim(S**self.power)
|
|
||||||
|
|
||||||
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
|
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
|
||||||
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
|
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
|
||||||
D = self.denormalize(mel_spectrogram)
|
D = self.denormalize(mel_spectrogram)
|
||||||
S = self._db_to_amp(D)
|
S = db_to_amp(x=D, gain=self.spec_gain, base=self.base)
|
||||||
S = self._mel_to_linear(S) # Convert back to linear
|
S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear
|
||||||
if self.preemphasis != 0:
|
W = self._griffin_lim(S**self.power)
|
||||||
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
|
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
|
||||||
return self._griffin_lim(S**self.power)
|
|
||||||
|
|
||||||
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
|
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
|
||||||
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
|
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
|
||||||
|
@ -509,60 +467,22 @@ class AudioProcessor(object):
|
||||||
np.ndarray: Normalized melspectrogram.
|
np.ndarray: Normalized melspectrogram.
|
||||||
"""
|
"""
|
||||||
S = self.denormalize(linear_spec)
|
S = self.denormalize(linear_spec)
|
||||||
S = self._db_to_amp(S)
|
S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
|
||||||
S = self._linear_to_mel(np.abs(S))
|
S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis)
|
||||||
S = self._amp_to_db(S)
|
S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
|
||||||
mel = self.normalize(S)
|
mel = self.normalize(S)
|
||||||
return mel
|
return mel
|
||||||
|
|
||||||
### STFT and ISTFT ###
|
def _griffin_lim(self, S):
|
||||||
def _stft(self, y: np.ndarray) -> np.ndarray:
|
return griffin_lim(
|
||||||
"""Librosa STFT wrapper.
|
spec=S,
|
||||||
|
num_iter=self.griffin_lim_iters,
|
||||||
Args:
|
|
||||||
y (np.ndarray): Audio signal.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Complex number array.
|
|
||||||
"""
|
|
||||||
return librosa.stft(
|
|
||||||
y=y,
|
|
||||||
n_fft=self.fft_size,
|
|
||||||
hop_length=self.hop_length,
|
hop_length=self.hop_length,
|
||||||
win_length=self.win_length,
|
win_length=self.win_length,
|
||||||
|
fft_size=self.fft_size,
|
||||||
pad_mode=self.stft_pad_mode,
|
pad_mode=self.stft_pad_mode,
|
||||||
window="hann",
|
|
||||||
center=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _istft(self, y: np.ndarray) -> np.ndarray:
|
|
||||||
"""Librosa iSTFT wrapper."""
|
|
||||||
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
|
|
||||||
|
|
||||||
def _griffin_lim(self, S):
|
|
||||||
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
|
||||||
try:
|
|
||||||
S_complex = np.abs(S).astype(np.complex)
|
|
||||||
except AttributeError: # np.complex is deprecated since numpy 1.20.0
|
|
||||||
S_complex = np.abs(S).astype(complex)
|
|
||||||
y = self._istft(S_complex * angles)
|
|
||||||
if not np.isfinite(y).all():
|
|
||||||
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
|
|
||||||
return np.array([0.0])
|
|
||||||
for _ in range(self.griffin_lim_iters):
|
|
||||||
angles = np.exp(1j * np.angle(self._stft(y)))
|
|
||||||
y = self._istft(S_complex * angles)
|
|
||||||
return y
|
|
||||||
|
|
||||||
def compute_stft_paddings(self, x, pad_sides=1):
|
|
||||||
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
|
|
||||||
(first and final frames)"""
|
|
||||||
assert pad_sides in (1, 2)
|
|
||||||
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
|
|
||||||
if pad_sides == 1:
|
|
||||||
return 0, pad
|
|
||||||
return pad // 2, pad // 2 + pad % 2
|
|
||||||
|
|
||||||
def compute_f0(self, x: np.ndarray) -> np.ndarray:
|
def compute_f0(self, x: np.ndarray) -> np.ndarray:
|
||||||
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
|
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
|
||||||
|
|
||||||
|
@ -581,8 +501,6 @@ class AudioProcessor(object):
|
||||||
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
|
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
|
||||||
>>> pitch = ap.compute_f0(wav)
|
>>> pitch = ap.compute_f0(wav)
|
||||||
"""
|
"""
|
||||||
assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
|
|
||||||
assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
|
|
||||||
# align F0 length to the spectrogram length
|
# align F0 length to the spectrogram length
|
||||||
if len(x) % self.hop_length == 0:
|
if len(x) % self.hop_length == 0:
|
||||||
x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
|
x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
|
||||||
|
@ -612,21 +530,24 @@ class AudioProcessor(object):
|
||||||
Returns:
|
Returns:
|
||||||
int: Last point without silence.
|
int: Last point without silence.
|
||||||
"""
|
"""
|
||||||
window_length = int(self.sample_rate * min_silence_sec)
|
return find_endpoint(
|
||||||
hop_length = int(window_length / 4)
|
wav=wav,
|
||||||
threshold = self._db_to_amp(-self.trim_db)
|
trim_db=self.trim_db,
|
||||||
for x in range(hop_length, len(wav) - window_length, hop_length):
|
sample_rate=self.sample_rate,
|
||||||
if np.max(wav[x : x + window_length]) < threshold:
|
min_silence_sec=min_silence_sec,
|
||||||
return x + hop_length
|
gain=self.spec_gain,
|
||||||
return len(wav)
|
base=self.base,
|
||||||
|
)
|
||||||
|
|
||||||
def trim_silence(self, wav):
|
def trim_silence(self, wav):
|
||||||
"""Trim silent parts with a threshold and 0.01 sec margin"""
|
"""Trim silent parts with a threshold and 0.01 sec margin"""
|
||||||
margin = int(self.sample_rate * 0.01)
|
return trim_silence(
|
||||||
wav = wav[margin:-margin]
|
wav=wav,
|
||||||
return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
|
sample_rate=self.sample_rate,
|
||||||
0
|
trim_db=self.trim_db,
|
||||||
]
|
win_length=self.win_length,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sound_norm(x: np.ndarray) -> np.ndarray:
|
def sound_norm(x: np.ndarray) -> np.ndarray:
|
||||||
|
@ -638,13 +559,7 @@ class AudioProcessor(object):
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: Volume normalized waveform.
|
np.ndarray: Volume normalized waveform.
|
||||||
"""
|
"""
|
||||||
return x / abs(x).max() * 0.95
|
return volume_norm(x=x)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _rms_norm(wav, db_level=-27):
|
|
||||||
r = 10 ** (db_level / 20)
|
|
||||||
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
|
|
||||||
return wav * a
|
|
||||||
|
|
||||||
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
|
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
|
||||||
"""Normalize the volume based on RMS of the signal.
|
"""Normalize the volume based on RMS of the signal.
|
||||||
|
@ -657,9 +572,7 @@ class AudioProcessor(object):
|
||||||
"""
|
"""
|
||||||
if db_level is None:
|
if db_level is None:
|
||||||
db_level = self.db_level
|
db_level = self.db_level
|
||||||
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
|
return rms_volume_norm(x=x, db_level=db_level)
|
||||||
wav = self._rms_norm(x, db_level)
|
|
||||||
return wav
|
|
||||||
|
|
||||||
### save and load ###
|
### save and load ###
|
||||||
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
|
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
|
||||||
|
@ -674,15 +587,10 @@ class AudioProcessor(object):
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: Loaded waveform.
|
np.ndarray: Loaded waveform.
|
||||||
"""
|
"""
|
||||||
if self.resample:
|
if sr is not None:
|
||||||
# loading with resampling. It is significantly slower.
|
x = load_wav(filename=filename, sample_rate=sr, resample=True)
|
||||||
x, sr = librosa.load(filename, sr=self.sample_rate)
|
|
||||||
elif sr is None:
|
|
||||||
# SF is faster than librosa for loading files
|
|
||||||
x, sr = sf.read(filename)
|
|
||||||
assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
|
|
||||||
else:
|
else:
|
||||||
x, sr = librosa.load(filename, sr=sr)
|
x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
|
||||||
if self.do_trim_silence:
|
if self.do_trim_silence:
|
||||||
try:
|
try:
|
||||||
x = self.trim_silence(x)
|
x = self.trim_silence(x)
|
||||||
|
@ -723,55 +631,3 @@ class AudioProcessor(object):
|
||||||
filename (str): Path to the wav file.
|
filename (str): Path to the wav file.
|
||||||
"""
|
"""
|
||||||
return librosa.get_duration(filename=filename)
|
return librosa.get_duration(filename=filename)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
|
|
||||||
mu = 2**qc - 1
|
|
||||||
# wav_abs = np.minimum(np.abs(wav), 1.0)
|
|
||||||
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
|
|
||||||
# Quantize signal to the specified number of levels.
|
|
||||||
signal = (signal + 1) / 2 * mu + 0.5
|
|
||||||
return np.floor(
|
|
||||||
signal,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def mulaw_decode(wav, qc):
|
|
||||||
"""Recovers waveform from quantized values."""
|
|
||||||
mu = 2**qc - 1
|
|
||||||
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def encode_16bits(x):
|
|
||||||
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def quantize(x: np.ndarray, bits: int) -> np.ndarray:
|
|
||||||
"""Quantize a waveform to a given number of bits.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
|
|
||||||
bits (int): Number of quantization bits.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Quantized waveform.
|
|
||||||
"""
|
|
||||||
return (x + 1.0) * (2**bits - 1) / 2
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def dequantize(x, bits):
|
|
||||||
"""Dequantize a waveform from the given number of bits."""
|
|
||||||
return 2 * x / (2**bits - 1) - 1
|
|
||||||
|
|
||||||
|
|
||||||
def _log(x, base):
|
|
||||||
if base == 10:
|
|
||||||
return np.log10(x)
|
|
||||||
return np.log(x)
|
|
||||||
|
|
||||||
|
|
||||||
def _exp(x, base):
|
|
||||||
if base == 10:
|
|
||||||
return np.power(10, x)
|
|
||||||
return np.exp(x)
|
|
||||||
|
|
146
TTS/utils/io.py
146
TTS/utils/io.py
|
@ -1,13 +1,9 @@
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import pickle as pickle_tts
|
import pickle as pickle_tts
|
||||||
import shutil
|
|
||||||
from typing import Any, Callable, Dict, Union
|
from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
|
||||||
|
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
from TTS.utils.generic_utils import get_user_data_dir
|
||||||
|
|
||||||
|
@ -28,34 +24,6 @@ class AttrDict(dict):
|
||||||
self.__dict__ = self
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
def copy_model_files(config: Coqpit, out_path, new_fields=None):
|
|
||||||
"""Copy config.json and other model files to training folder and add
|
|
||||||
new fields.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (Coqpit): Coqpit config defining the training run.
|
|
||||||
out_path (str): output path to copy the file.
|
|
||||||
new_fields (dict): new fileds to be added or edited
|
|
||||||
in the config file.
|
|
||||||
"""
|
|
||||||
copy_config_path = os.path.join(out_path, "config.json")
|
|
||||||
# add extra information fields
|
|
||||||
if new_fields:
|
|
||||||
config.update(new_fields, allow_new=True)
|
|
||||||
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
|
|
||||||
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
|
|
||||||
json.dump(config.to_dict(), f, indent=4)
|
|
||||||
|
|
||||||
# copy model stats file if available
|
|
||||||
if config.audio.stats_path is not None:
|
|
||||||
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
|
|
||||||
filesystem = fsspec.get_mapper(copy_stats_path).fs
|
|
||||||
if not filesystem.exists(copy_stats_path):
|
|
||||||
with fsspec.open(config.audio.stats_path, "rb") as source_file:
|
|
||||||
with fsspec.open(copy_stats_path, "wb") as target_file:
|
|
||||||
shutil.copyfileobj(source_file, target_file)
|
|
||||||
|
|
||||||
|
|
||||||
def load_fsspec(
|
def load_fsspec(
|
||||||
path: str,
|
path: str,
|
||||||
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
|
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
|
||||||
|
@ -100,117 +68,3 @@ def load_checkpoint(
|
||||||
if eval:
|
if eval:
|
||||||
model.eval()
|
model.eval()
|
||||||
return model, state
|
return model, state
|
||||||
|
|
||||||
|
|
||||||
def save_fsspec(state: Any, path: str, **kwargs):
|
|
||||||
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: State object to save
|
|
||||||
path: Any path or url supported by fsspec.
|
|
||||||
**kwargs: Keyword arguments forwarded to torch.save.
|
|
||||||
"""
|
|
||||||
with fsspec.open(path, "wb") as f:
|
|
||||||
torch.save(state, f, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
|
|
||||||
if hasattr(model, "module"):
|
|
||||||
model_state = model.module.state_dict()
|
|
||||||
else:
|
|
||||||
model_state = model.state_dict()
|
|
||||||
if isinstance(optimizer, list):
|
|
||||||
optimizer_state = [optim.state_dict() for optim in optimizer]
|
|
||||||
elif optimizer.__class__.__name__ == "CapacitronOptimizer":
|
|
||||||
optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
|
|
||||||
else:
|
|
||||||
optimizer_state = optimizer.state_dict() if optimizer is not None else None
|
|
||||||
|
|
||||||
if isinstance(scaler, list):
|
|
||||||
scaler_state = [s.state_dict() for s in scaler]
|
|
||||||
else:
|
|
||||||
scaler_state = scaler.state_dict() if scaler is not None else None
|
|
||||||
|
|
||||||
if isinstance(config, Coqpit):
|
|
||||||
config = config.to_dict()
|
|
||||||
|
|
||||||
state = {
|
|
||||||
"config": config,
|
|
||||||
"model": model_state,
|
|
||||||
"optimizer": optimizer_state,
|
|
||||||
"scaler": scaler_state,
|
|
||||||
"step": current_step,
|
|
||||||
"epoch": epoch,
|
|
||||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
|
||||||
}
|
|
||||||
state.update(kwargs)
|
|
||||||
save_fsspec(state, output_path)
|
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(
|
|
||||||
config,
|
|
||||||
model,
|
|
||||||
optimizer,
|
|
||||||
scaler,
|
|
||||||
current_step,
|
|
||||||
epoch,
|
|
||||||
output_folder,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
file_name = "checkpoint_{}.pth".format(current_step)
|
|
||||||
checkpoint_path = os.path.join(output_folder, file_name)
|
|
||||||
print("\n > CHECKPOINT : {}".format(checkpoint_path))
|
|
||||||
save_model(
|
|
||||||
config,
|
|
||||||
model,
|
|
||||||
optimizer,
|
|
||||||
scaler,
|
|
||||||
current_step,
|
|
||||||
epoch,
|
|
||||||
checkpoint_path,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def save_best_model(
|
|
||||||
current_loss,
|
|
||||||
best_loss,
|
|
||||||
config,
|
|
||||||
model,
|
|
||||||
optimizer,
|
|
||||||
scaler,
|
|
||||||
current_step,
|
|
||||||
epoch,
|
|
||||||
out_path,
|
|
||||||
keep_all_best=False,
|
|
||||||
keep_after=10000,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if current_loss < best_loss:
|
|
||||||
best_model_name = f"best_model_{current_step}.pth"
|
|
||||||
checkpoint_path = os.path.join(out_path, best_model_name)
|
|
||||||
print(" > BEST MODEL : {}".format(checkpoint_path))
|
|
||||||
save_model(
|
|
||||||
config,
|
|
||||||
model,
|
|
||||||
optimizer,
|
|
||||||
scaler,
|
|
||||||
current_step,
|
|
||||||
epoch,
|
|
||||||
checkpoint_path,
|
|
||||||
model_loss=current_loss,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
fs = fsspec.get_mapper(out_path).fs
|
|
||||||
# only delete previous if current is saved successfully
|
|
||||||
if not keep_all_best or (current_step < keep_after):
|
|
||||||
model_names = fs.glob(os.path.join(out_path, "best_model*.pth"))
|
|
||||||
for model_name in model_names:
|
|
||||||
if os.path.basename(model_name) != best_model_name:
|
|
||||||
fs.rm(model_name)
|
|
||||||
# create a shortcut which always points to the currently best model
|
|
||||||
shortcut_name = "best_model.pth"
|
|
||||||
shortcut_path = os.path.join(out_path, shortcut_name)
|
|
||||||
fs.copy(checkpoint_path, shortcut_path)
|
|
||||||
best_loss = current_loss
|
|
||||||
return best_loss
|
|
||||||
|
|
|
@ -1,5 +1,278 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from coqpit import Coqpit
|
||||||
|
|
||||||
from TTS.vc.configs.shared_configs import BaseVCConfig
|
from TTS.vc.configs.shared_configs import BaseVCConfig
|
||||||
from TTS.vc.models.freevc import FreeVCArgs, FreeVCAudioConfig, FreeVCConfig
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FreeVCAudioConfig(Coqpit):
|
||||||
|
"""Audio configuration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_wav_value (float):
|
||||||
|
The maximum value of the waveform.
|
||||||
|
|
||||||
|
input_sample_rate (int):
|
||||||
|
The sampling rate of the input waveform.
|
||||||
|
|
||||||
|
output_sample_rate (int):
|
||||||
|
The sampling rate of the output waveform.
|
||||||
|
|
||||||
|
filter_length (int):
|
||||||
|
The length of the filter.
|
||||||
|
|
||||||
|
hop_length (int):
|
||||||
|
The hop length.
|
||||||
|
|
||||||
|
win_length (int):
|
||||||
|
The window length.
|
||||||
|
|
||||||
|
n_mel_channels (int):
|
||||||
|
The number of mel channels.
|
||||||
|
|
||||||
|
mel_fmin (float):
|
||||||
|
The minimum frequency of the mel filterbank.
|
||||||
|
|
||||||
|
mel_fmax (Optional[float]):
|
||||||
|
The maximum frequency of the mel filterbank.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_wav_value: float = field(default=32768.0)
|
||||||
|
input_sample_rate: int = field(default=16000)
|
||||||
|
output_sample_rate: int = field(default=24000)
|
||||||
|
filter_length: int = field(default=1280)
|
||||||
|
hop_length: int = field(default=320)
|
||||||
|
win_length: int = field(default=1280)
|
||||||
|
n_mel_channels: int = field(default=80)
|
||||||
|
mel_fmin: float = field(default=0.0)
|
||||||
|
mel_fmax: Optional[float] = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FreeVCArgs(Coqpit):
    """FreeVC model arguments

    Args:
        spec_channels (int):
            The number of channels in the spectrogram.

        inter_channels (int):
            The number of channels in the intermediate layers.

        hidden_channels (int):
            The number of channels in the hidden layers.

        filter_channels (int):
            The number of channels in the filter layers.

        n_heads (int):
            The number of attention heads.

        n_layers (int):
            The number of layers.

        kernel_size (int):
            The size of the kernel.

        p_dropout (float):
            The dropout probability.

        resblock (str):
            The type of residual block.

        resblock_kernel_sizes (List[int]):
            The kernel sizes for the residual blocks.

        resblock_dilation_sizes (List[List[int]]):
            The dilation sizes for the residual blocks.

        upsample_rates (List[int]):
            The upsample rates.

        upsample_initial_channel (int):
            The number of channels in the initial upsample layer.

        upsample_kernel_sizes (List[int]):
            The kernel sizes for the upsample layers.

        n_layers_q (int):
            The number of layers in the quantization network.

        use_spectral_norm (bool):
            Whether to use spectral normalization.

        gin_channels (int):
            The number of channels in the global conditioning vector.

        ssl_dim (int):
            The dimension of the self-supervised learning embedding.

        use_spk (bool):
            Whether to use external speaker encoder.
    """

    spec_channels: int = field(default=641)
    inter_channels: int = field(default=192)
    hidden_channels: int = field(default=192)
    filter_channels: int = field(default=768)
    n_heads: int = field(default=2)
    n_layers: int = field(default=6)
    kernel_size: int = field(default=3)
    p_dropout: float = field(default=0.1)
    resblock: str = field(default="1")
    resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])
    resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
    upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2])
    upsample_initial_channel: int = field(default=512)
    upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
    n_layers_q: int = field(default=3)
    use_spectral_norm: bool = field(default=False)
    gin_channels: int = field(default=256)
    ssl_dim: int = field(default=1024)
    use_spk: bool = field(default=False)
    num_spks: int = field(default=0)
    segment_size: int = field(default=8960)


@dataclass
class FreeVCConfig(BaseVCConfig):
    """Defines parameters for FreeVC End2End TTS model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (FreeVCArgs):
            Model architecture arguments. Defaults to `FreeVCArgs()`.

        audio (FreeVCAudioConfig):
            Audio processing configuration. Defaults to `FreeVCAudioConfig()`.

        grad_clip (List):
            Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

        lr_gen (float):
            Initial learning rate for the generator. Defaults to 0.0002.

        lr_disc (float):
            Initial learning rate for the discriminator. Defaults to 0.0002.

        lr_scheduler_gen (str):
            Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_gen_params (dict):
            Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        lr_scheduler_disc (str):
            Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_disc_params (dict):
            Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        scheduler_after_epoch (bool):
            If true, step the schedulers after each epoch else after each step. Defaults to `False`.

        optimizer (str):
            Name of the optimizer to use with both the generator and the discriminator networks. One of the
            `torch.optim.*`. Defaults to `AdamW`.

        kl_loss_alpha (float):
            Loss weight for KL loss. Defaults to 1.0.

        disc_loss_alpha (float):
            Loss weight for the discriminator loss. Defaults to 1.0.

        gen_loss_alpha (float):
            Loss weight for the generator loss. Defaults to 1.0.

        feat_loss_alpha (float):
            Loss weight for the feature matching loss. Defaults to 1.0.

        mel_loss_alpha (float):
            Loss weight for the mel loss. Defaults to 45.0.

        return_wav (bool):
            If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

        compute_linear_spec (bool):
            If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

        use_weighted_sampler (bool):
            If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.

        weighted_sampler_attrs (dict):
            Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
            by overweighting `root_path` by 2.0. Defaults to `{}`.

        weighted_sampler_multipliers (dict):
            Weight each unique value of a key returned by the formatter for weighted sampling.
            For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`.
            It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.

        r (int):
            Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

        add_blank (bool):
            If true, a blank token is added in between every character. Defaults to `True`.

        test_sentences (List[List]):
            List of sentences with speaker and language information to be used for testing.

        language_ids_file (str):
            Path to the language ids file.

        use_language_embedding (bool):
            If true, language embedding is used. Defaults to `False`.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

    Example:

        >>> from TTS.vc.configs.freevc_config import FreeVCConfig
        >>> config = FreeVCConfig()
    """

    model: str = "freevc"
    # model specific params
    model_args: FreeVCArgs = field(default_factory=FreeVCArgs)
    audio: FreeVCAudioConfig = field(default_factory=FreeVCAudioConfig)

    # optimizer
    # TODO with training support

    # loss params
    # TODO with training support

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # sampler params
    use_weighted_sampler: bool = False  # TODO: move it to the base config
    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

    # overrides
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # multi-speaker settings
    # use speaker embedding layer
    num_speakers: int = 0
    speakers_file: str = None
    speaker_embedding_channels: int = 256

    # use d-vectors
    use_d_vector_file: bool = False
    d_vector_file: List[str] = None
    d_vector_dim: int = None

    def __post_init__(self):
        for key, val in self.model_args.items():
            if hasattr(self, key):
                self[key] = val
|
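The whole file is a Coqpit dataclass tree, so settings can be overridden at construction time. A minimal sketch of how the new config might be instantiated, assuming only the classes and fields shown in this diff; the dataset paths and sampler weights below are made up for illustration:

```python
from TTS.vc.configs.freevc_config import FreeVCArgs, FreeVCAudioConfig, FreeVCConfig

# Override a few defaults; the paths and weights here are illustrative only.
config = FreeVCConfig(
    audio=FreeVCAudioConfig(input_sample_rate=16000, output_sample_rate=24000),
    model_args=FreeVCArgs(use_spk=True),
    use_weighted_sampler=True,
    weighted_sampler_attrs={"root_path": 1.0},
    weighted_sampler_multipliers={"root_path": {"/data/corpus_a/": 1.0, "/data/corpus_b/": 0.5}},
)

# __post_init__ mirrors any model_args keys that also exist on the config itself.
print(config.audio.output_sample_rate)  # 24000
print(config.model_args.use_spk)  # True
```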
@@ -1,4 +1,3 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
|
@@ -6,15 +5,17 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
from torch.nn.utils import spectral_norm
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
import TTS.vc.modules.freevc.commons as commons
|
import TTS.vc.modules.freevc.commons as commons
|
||||||
import TTS.vc.modules.freevc.modules as modules
|
import TTS.vc.modules.freevc.modules as modules
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.utils.io import load_fsspec, save_checkpoint
|
from TTS.utils.io import load_fsspec
|
||||||
from TTS.vc.configs.shared_configs import BaseVCConfig
|
from TTS.vc.configs.freevc_config import FreeVCConfig
|
||||||
from TTS.vc.models.base_vc import BaseVC
|
from TTS.vc.models.base_vc import BaseVC
|
||||||
from TTS.vc.modules.freevc.commons import get_padding, init_weights
|
from TTS.vc.modules.freevc.commons import get_padding, init_weights
|
||||||
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
|
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
|
||||||
|
@@ -153,9 +154,9 @@ class Generator(torch.nn.Module):
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
print("Removing weight norm...")
|
print("Removing weight norm...")
|
||||||
for l in self.ups:
|
for l in self.ups:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
for l in self.resblocks:
|
for l in self.resblocks:
|
||||||
l.remove_weight_norm()
|
remove_parametrizations(l, "weight")
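This diff repeatedly swaps the deprecated `torch.nn.utils.weight_norm`/`remove_weight_norm` pair for the parametrization-based API. A minimal sketch of the new calls, assuming a recent PyTorch that ships `torch.nn.utils.parametrizations`; the layer sizes are arbitrary:

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

# Register weight norm as a parametrization; the weight is recomputed from its
# magnitude/direction factors on every forward pass.
conv = weight_norm(nn.Conv1d(192, 192, kernel_size=3, padding=1))
y = conv(torch.randn(1, 192, 100))

# For inference, fold the parametrization back into a plain weight tensor,
# mirroring the updated remove_weight_norm() methods in this diff.
remove_parametrizations(conv, "weight")
```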
class DiscriminatorP(torch.nn.Module):
|
class DiscriminatorP(torch.nn.Module):
|
||||||
|
@@ -294,136 +295,6 @@ class SpeakerEncoder(torch.nn.Module):
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FreeVCAudioConfig(Coqpit):
|
|
||||||
"""Audio configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_wav_value (float):
|
|
||||||
The maximum value of the waveform.
|
|
||||||
|
|
||||||
input_sample_rate (int):
|
|
||||||
The sampling rate of the input waveform.
|
|
||||||
|
|
||||||
output_sample_rate (int):
|
|
||||||
The sampling rate of the output waveform.
|
|
||||||
|
|
||||||
filter_length (int):
|
|
||||||
The length of the filter.
|
|
||||||
|
|
||||||
hop_length (int):
|
|
||||||
The hop length.
|
|
||||||
|
|
||||||
win_length (int):
|
|
||||||
The window length.
|
|
||||||
|
|
||||||
n_mel_channels (int):
|
|
||||||
The number of mel channels.
|
|
||||||
|
|
||||||
mel_fmin (float):
|
|
||||||
The minimum frequency of the mel filterbank.
|
|
||||||
|
|
||||||
mel_fmax (Optional[float]):
|
|
||||||
The maximum frequency of the mel filterbank.
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_wav_value: float = field(default=32768.0)
|
|
||||||
input_sample_rate: int = field(default=16000)
|
|
||||||
output_sample_rate: int = field(default=24000)
|
|
||||||
filter_length: int = field(default=1280)
|
|
||||||
hop_length: int = field(default=320)
|
|
||||||
win_length: int = field(default=1280)
|
|
||||||
n_mel_channels: int = field(default=80)
|
|
||||||
mel_fmin: float = field(default=0.0)
|
|
||||||
mel_fmax: Optional[float] = field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FreeVCArgs(Coqpit):
|
|
||||||
"""FreeVC model arguments
|
|
||||||
|
|
||||||
Args:
|
|
||||||
spec_channels (int):
|
|
||||||
The number of channels in the spectrogram.
|
|
||||||
|
|
||||||
inter_channels (int):
|
|
||||||
The number of channels in the intermediate layers.
|
|
||||||
|
|
||||||
hidden_channels (int):
|
|
||||||
The number of channels in the hidden layers.
|
|
||||||
|
|
||||||
filter_channels (int):
|
|
||||||
The number of channels in the filter layers.
|
|
||||||
|
|
||||||
n_heads (int):
|
|
||||||
The number of attention heads.
|
|
||||||
|
|
||||||
n_layers (int):
|
|
||||||
The number of layers.
|
|
||||||
|
|
||||||
kernel_size (int):
|
|
||||||
The size of the kernel.
|
|
||||||
|
|
||||||
p_dropout (float):
|
|
||||||
The dropout probability.
|
|
||||||
|
|
||||||
resblock (str):
|
|
||||||
The type of residual block.
|
|
||||||
|
|
||||||
resblock_kernel_sizes (List[int]):
|
|
||||||
The kernel sizes for the residual blocks.
|
|
||||||
|
|
||||||
resblock_dilation_sizes (List[List[int]]):
|
|
||||||
The dilation sizes for the residual blocks.
|
|
||||||
|
|
||||||
upsample_rates (List[int]):
|
|
||||||
The upsample rates.
|
|
||||||
|
|
||||||
upsample_initial_channel (int):
|
|
||||||
The number of channels in the initial upsample layer.
|
|
||||||
|
|
||||||
upsample_kernel_sizes (List[int]):
|
|
||||||
The kernel sizes for the upsample layers.
|
|
||||||
|
|
||||||
n_layers_q (int):
|
|
||||||
The number of layers in the quantization network.
|
|
||||||
|
|
||||||
use_spectral_norm (bool):
|
|
||||||
Whether to use spectral normalization.
|
|
||||||
|
|
||||||
gin_channels (int):
|
|
||||||
The number of channels in the global conditioning vector.
|
|
||||||
|
|
||||||
ssl_dim (int):
|
|
||||||
The dimension of the self-supervised learning embedding.
|
|
||||||
|
|
||||||
use_spk (bool):
|
|
||||||
Whether to use external speaker encoder.
|
|
||||||
"""
|
|
||||||
|
|
||||||
spec_channels: int = field(default=641)
|
|
||||||
inter_channels: int = field(default=192)
|
|
||||||
hidden_channels: int = field(default=192)
|
|
||||||
filter_channels: int = field(default=768)
|
|
||||||
n_heads: int = field(default=2)
|
|
||||||
n_layers: int = field(default=6)
|
|
||||||
kernel_size: int = field(default=3)
|
|
||||||
p_dropout: float = field(default=0.1)
|
|
||||||
resblock: str = field(default="1")
|
|
||||||
resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])
|
|
||||||
resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
|
||||||
upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2])
|
|
||||||
upsample_initial_channel: int = field(default=512)
|
|
||||||
upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
|
|
||||||
n_layers_q: int = field(default=3)
|
|
||||||
use_spectral_norm: bool = field(default=False)
|
|
||||||
gin_channels: int = field(default=256)
|
|
||||||
ssl_dim: int = field(default=1024)
|
|
||||||
use_spk: bool = field(default=False)
|
|
||||||
num_spks: int = field(default=0)
|
|
||||||
segment_size: int = field(default=8960)
|
|
||||||
|
|
||||||
|
|
||||||
class FreeVC(BaseVC):
|
class FreeVC(BaseVC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@@ -677,7 +548,7 @@ class FreeVC(BaseVC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
|
def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True):
|
||||||
model = FreeVC(config)
|
model = FreeVC(config)
|
||||||
return model
|
return model
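With the annotation corrected to `FreeVCConfig`, building the model from a config reads as follows. This is a sketch using only the symbols visible in this diff; constructing `FreeVC` may fetch auxiliary models (e.g. the WavLM feature extractor) on first use:

```python
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.freevc import FreeVC

config = FreeVCConfig()
model = FreeVC.init_from_config(config)  # simply returns FreeVC(config)
```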
@@ -689,145 +560,3 @@ class FreeVC(BaseVC):
|
||||||
|
|
||||||
def train_step():
|
def train_step():
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FreeVCConfig(BaseVCConfig):
|
|
||||||
"""Defines parameters for FreeVC End2End TTS model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (str):
|
|
||||||
Model name. Do not change unless you know what you are doing.
|
|
||||||
|
|
||||||
model_args (FreeVCArgs):
|
|
||||||
Model architecture arguments. Defaults to `FreeVCArgs()`.
|
|
||||||
|
|
||||||
audio (FreeVCAudioConfig):
|
|
||||||
Audio processing configuration. Defaults to `FreeVCAudioConfig()`.
|
|
||||||
|
|
||||||
grad_clip (List):
|
|
||||||
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
|
|
||||||
|
|
||||||
lr_gen (float):
|
|
||||||
Initial learning rate for the generator. Defaults to 0.0002.
|
|
||||||
|
|
||||||
lr_disc (float):
|
|
||||||
Initial learning rate for the discriminator. Defaults to 0.0002.
|
|
||||||
|
|
||||||
lr_scheduler_gen (str):
|
|
||||||
Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
|
|
||||||
`ExponentialLR`.
|
|
||||||
|
|
||||||
lr_scheduler_gen_params (dict):
|
|
||||||
Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
|
|
||||||
|
|
||||||
lr_scheduler_disc (str):
|
|
||||||
Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
|
|
||||||
`ExponentialLR`.
|
|
||||||
|
|
||||||
lr_scheduler_disc_params (dict):
|
|
||||||
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
|
|
||||||
|
|
||||||
scheduler_after_epoch (bool):
|
|
||||||
If true, step the schedulers after each epoch else after each step. Defaults to `False`.
|
|
||||||
|
|
||||||
optimizer (str):
|
|
||||||
Name of the optimizer to use with both the generator and the discriminator networks. One of the
|
|
||||||
`torch.optim.*`. Defaults to `AdamW`.
|
|
||||||
|
|
||||||
kl_loss_alpha (float):
|
|
||||||
Loss weight for KL loss. Defaults to 1.0.
|
|
||||||
|
|
||||||
disc_loss_alpha (float):
|
|
||||||
Loss weight for the discriminator loss. Defaults to 1.0.
|
|
||||||
|
|
||||||
gen_loss_alpha (float):
|
|
||||||
Loss weight for the generator loss. Defaults to 1.0.
|
|
||||||
|
|
||||||
feat_loss_alpha (float):
|
|
||||||
Loss weight for the feature matching loss. Defaults to 1.0.
|
|
||||||
|
|
||||||
mel_loss_alpha (float):
|
|
||||||
Loss weight for the mel loss. Defaults to 45.0.
|
|
||||||
|
|
||||||
return_wav (bool):
|
|
||||||
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
|
|
||||||
|
|
||||||
compute_linear_spec (bool):
|
|
||||||
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
|
||||||
|
|
||||||
use_weighted_sampler (bool):
|
|
||||||
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
|
|
||||||
|
|
||||||
weighted_sampler_attrs (dict):
|
|
||||||
Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
|
|
||||||
by overweighting `root_path` by 2.0. Defaults to `{}`.
|
|
||||||
|
|
||||||
weighted_sampler_multipliers (dict):
|
|
||||||
Weight each unique value of a key returned by the formatter for weighted sampling.
|
|
||||||
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`.
|
|
||||||
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
|
|
||||||
|
|
||||||
r (int):
|
|
||||||
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
|
||||||
|
|
||||||
add_blank (bool):
|
|
||||||
If true, a blank token is added in between every character. Defaults to `True`.
|
|
||||||
|
|
||||||
test_sentences (List[List]):
|
|
||||||
List of sentences with speaker and language information to be used for testing.
|
|
||||||
|
|
||||||
language_ids_file (str):
|
|
||||||
Path to the language ids file.
|
|
||||||
|
|
||||||
use_language_embedding (bool):
|
|
||||||
If true, language embedding is used. Defaults to `False`.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
>>> from TTS.tts.configs.freevc_config import FreeVCConfig
|
|
||||||
>>> config = FreeVCConfig()
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: str = "freevc"
|
|
||||||
# model specific params
|
|
||||||
model_args: FreeVCArgs = field(default_factory=FreeVCArgs)
|
|
||||||
audio: FreeVCAudioConfig = field(default_factory=FreeVCAudioConfig)
|
|
||||||
|
|
||||||
# optimizer
|
|
||||||
# TODO with training support
|
|
||||||
|
|
||||||
# loss params
|
|
||||||
# TODO with training support
|
|
||||||
|
|
||||||
# data loader params
|
|
||||||
return_wav: bool = True
|
|
||||||
compute_linear_spec: bool = True
|
|
||||||
|
|
||||||
# sampler params
|
|
||||||
use_weighted_sampler: bool = False # TODO: move it to the base config
|
|
||||||
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
|
|
||||||
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
|
|
||||||
|
|
||||||
# overrides
|
|
||||||
r: int = 1 # DO NOT CHANGE
|
|
||||||
add_blank: bool = True
|
|
||||||
|
|
||||||
# multi-speaker settings
|
|
||||||
# use speaker embedding layer
|
|
||||||
num_speakers: int = 0
|
|
||||||
speakers_file: str = None
|
|
||||||
speaker_embedding_channels: int = 256
|
|
||||||
|
|
||||||
# use d-vectors
|
|
||||||
use_d_vector_file: bool = False
|
|
||||||
d_vector_file: List[str] = None
|
|
||||||
d_vector_dim: int = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
for key, val in self.model_args.items():
|
|
||||||
if hasattr(self, key):
|
|
||||||
self[key] = val
|
|
||||||
|
|
|
@@ -1,13 +1,9 @@
|
||||||
import copy
|
|
||||||
import math
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import scipy
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
from torch.nn import Conv1d
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
import TTS.vc.modules.freevc.commons as commons
|
import TTS.vc.modules.freevc.commons as commons
|
||||||
from TTS.vc.modules.freevc.commons import get_padding, init_weights
|
from TTS.vc.modules.freevc.commons import get_padding, init_weights
|
||||||
|
@ -122,7 +118,7 @@ class WN(torch.nn.Module):
|
||||||
|
|
||||||
if gin_channels != 0:
|
if gin_channels != 0:
|
||||||
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
|
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
|
||||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
|
||||||
|
|
||||||
for i in range(n_layers):
|
for i in range(n_layers):
|
||||||
dilation = dilation_rate**i
|
dilation = dilation_rate**i
|
||||||
|
@ -130,7 +126,7 @@ class WN(torch.nn.Module):
|
||||||
in_layer = torch.nn.Conv1d(
|
in_layer = torch.nn.Conv1d(
|
||||||
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
|
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
|
||||||
)
|
)
|
||||||
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
|
||||||
self.in_layers.append(in_layer)
|
self.in_layers.append(in_layer)
|
||||||
|
|
||||||
# last one is not necessary
|
# last one is not necessary
|
||||||
|
@ -140,7 +136,7 @@ class WN(torch.nn.Module):
|
||||||
res_skip_channels = hidden_channels
|
res_skip_channels = hidden_channels
|
||||||
|
|
||||||
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
|
||||||
self.res_skip_layers.append(res_skip_layer)
|
self.res_skip_layers.append(res_skip_layer)
|
||||||
|
|
||||||
def forward(self, x, x_mask, g=None, **kwargs):
|
def forward(self, x, x_mask, g=None, **kwargs):
|
||||||
|
@ -172,11 +168,11 @@ class WN(torch.nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
if self.gin_channels != 0:
|
if self.gin_channels != 0:
|
||||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
remove_parametrizations(self.cond_layer, "weight")
|
||||||
for l in self.in_layers:
|
for l in self.in_layers:
|
||||||
torch.nn.utils.remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
for l in self.res_skip_layers:
|
for l in self.res_skip_layers:
|
||||||
torch.nn.utils.remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
|
|
||||||
|
|
||||||
class ResBlock1(torch.nn.Module):
|
class ResBlock1(torch.nn.Module):
|
||||||
|
@ -250,9 +246,9 @@ class ResBlock1(torch.nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for l in self.convs1:
|
for l in self.convs1:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
for l in self.convs2:
|
for l in self.convs2:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
|
|
||||||
|
|
||||||
class ResBlock2(torch.nn.Module):
|
class ResBlock2(torch.nn.Module):
|
||||||
|
@ -297,7 +293,7 @@ class ResBlock2(torch.nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for l in self.convs:
|
for l in self.convs:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
|
|
||||||
|
|
||||||
class Log(nn.Module):
|
class Log(nn.Module):
|
||||||
|
|
|
@ -497,7 +497,7 @@ class TransformerEncoder(nn.Module):
|
||||||
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
||||||
nn.init.constant_(self.pos_conv.bias, 0)
|
nn.init.constant_(self.pos_conv.bias, 0)
|
||||||
|
|
||||||
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||||
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
||||||
|
|
||||||
if hasattr(args, "relative_position_embedding"):
|
if hasattr(args, "relative_position_embedding"):
|
||||||
|
|
|
@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
|
||||||
use_noise_augment: bool = False
|
use_noise_augment: bool = False
|
||||||
use_cache: bool = True
|
use_cache: bool = True
|
||||||
steps_to_start_discriminator: int = 200000
|
steps_to_start_discriminator: int = 200000
|
||||||
|
target_loss: str = "loss_1"
|
||||||
|
|
||||||
# LOSS PARAMETERS - overrides
|
# LOSS PARAMETERS - overrides
|
||||||
use_stft_loss: bool = True
|
use_stft_loss: bool = True
|
||||||
|
|
|
@ -7,6 +7,7 @@ from coqpit import Coqpit
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
|
||||||
|
|
||||||
|
|
||||||
def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
|
def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
|
||||||
|
@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
|
||||||
mel = ap.melspectrogram(y)
|
mel = ap.melspectrogram(y)
|
||||||
np.save(mel_path, mel)
|
np.save(mel_path, mel)
|
||||||
if isinstance(config.mode, int):
|
if isinstance(config.mode, int):
|
||||||
quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
|
quant = (
|
||||||
|
mulaw_encode(wav=y, mulaw_qc=config.mode)
|
||||||
|
if config.model_args.mulaw
|
||||||
|
else quantize(x=y, quantize_bits=config.mode)
|
||||||
|
)
|
||||||
np.save(quant_path, quant)
|
np.save(quant_path, quant)
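Preprocessing now uses the functional helpers from `TTS.utils.audio.numpy_transforms` instead of `AudioProcessor` methods. A small sketch with the keyword signatures shown in this hunk; the sine wave and bit depth are placeholders:

```python
import numpy as np

from TTS.utils.audio.numpy_transforms import mulaw_decode, mulaw_encode, quantize

# Placeholder 16 kHz waveform standing in for a real training clip.
sr = 16000
wav = (0.5 * np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr)).astype(np.float32)

bits = 10  # plays the role of config.mode
mu = mulaw_encode(wav=wav, mulaw_qc=bits)       # companded integer codes
lin = quantize(x=wav, quantize_bits=bits)       # plain linear quantization
restored = mulaw_decode(wav=mu, mulaw_qc=bits)  # back to float audio, as done at inference
```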
@ -2,6 +2,8 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
|
||||||
|
|
||||||
|
|
||||||
class WaveRNNDataset(Dataset):
|
class WaveRNNDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
|
@ -66,7 +68,9 @@ class WaveRNNDataset(Dataset):
|
||||||
x_input = audio
|
x_input = audio
|
||||||
elif isinstance(self.mode, int):
|
elif isinstance(self.mode, int):
|
||||||
x_input = (
|
x_input = (
|
||||||
self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
|
mulaw_encode(wav=audio, mulaw_qc=self.mode)
|
||||||
|
if self.mulaw
|
||||||
|
else quantize(x=audio, quantize_bits=self.mode)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unknown dataset mode - ", self.mode)
|
raise RuntimeError("Unknown dataset mode - ", self.mode)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
|
@ -10,14 +11,16 @@ class ResStack(nn.Module):
|
||||||
resstack += [
|
resstack += [
|
||||||
nn.LeakyReLU(0.2),
|
nn.LeakyReLU(0.2),
|
||||||
nn.ReflectionPad1d(dilation),
|
nn.ReflectionPad1d(dilation),
|
||||||
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)),
|
nn.utils.parametrizations.weight_norm(
|
||||||
|
nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
|
||||||
|
),
|
||||||
nn.LeakyReLU(0.2),
|
nn.LeakyReLU(0.2),
|
||||||
nn.ReflectionPad1d(padding),
|
nn.ReflectionPad1d(padding),
|
||||||
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
|
nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
|
||||||
]
|
]
|
||||||
self.resstack = nn.Sequential(*resstack)
|
self.resstack = nn.Sequential(*resstack)
|
||||||
|
|
||||||
self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
|
self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x1 = self.shortcut(x)
|
x1 = self.shortcut(x)
|
||||||
|
@ -25,13 +28,13 @@ class ResStack(nn.Module):
|
||||||
return x1 + x2
|
return x1 + x2
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.shortcut)
|
remove_parametrizations(self.shortcut, "weight")
|
||||||
nn.utils.remove_weight_norm(self.resstack[2])
|
remove_parametrizations(self.resstack[2], "weight")
|
||||||
nn.utils.remove_weight_norm(self.resstack[5])
|
remove_parametrizations(self.resstack[5], "weight")
|
||||||
nn.utils.remove_weight_norm(self.resstack[8])
|
remove_parametrizations(self.resstack[8], "weight")
|
||||||
nn.utils.remove_weight_norm(self.resstack[11])
|
remove_parametrizations(self.resstack[11], "weight")
|
||||||
nn.utils.remove_weight_norm(self.resstack[14])
|
remove_parametrizations(self.resstack[14], "weight")
|
||||||
nn.utils.remove_weight_norm(self.resstack[17])
|
remove_parametrizations(self.resstack[17], "weight")
|
||||||
|
|
||||||
|
|
||||||
class MRF(nn.Module):
|
class MRF(nn.Module):
|
||||||
|
|
|
@ -195,10 +195,10 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
|
||||||
if isinstance(scores_fake, list):
|
if isinstance(scores_fake, list):
|
||||||
# multi-scale loss
|
# multi-scale loss
|
||||||
for score_fake, score_real in zip(scores_fake, scores_real):
|
for score_fake, score_real in zip(scores_fake, scores_real):
|
||||||
total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real)
|
total_loss, real_loss_, fake_loss_ = loss_func(score_fake=score_fake, score_real=score_real)
|
||||||
loss += total_loss
|
loss += total_loss
|
||||||
real_loss += real_loss
|
real_loss += real_loss_
|
||||||
fake_loss += fake_loss
|
fake_loss += fake_loss_
|
||||||
# normalize loss values with number of scales (discriminators)
|
# normalize loss values with number of scales (discriminators)
|
||||||
loss /= len(scores_fake)
|
loss /= len(scores_fake)
|
||||||
real_loss /= len(scores_real)
|
real_loss /= len(scores_real)
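The rename to `real_loss_`/`fake_loss_` stops the per-scale returns from shadowing the running totals, which previously corrupted the averaged discriminator losses. A stripped-down sketch of the corrected aggregation, with a dummy hinge loss standing in for the project's `loss_func`:

```python
import torch

def hinge_d_loss(score_fake, score_real):
    # Dummy stand-in for the project's discriminator loss functions.
    real_loss = torch.relu(1.0 - score_real).mean()
    fake_loss = torch.relu(1.0 + score_fake).mean()
    return real_loss + fake_loss, real_loss, fake_loss

def apply_d_loss(scores_fake, scores_real, loss_func):
    loss = real_loss = fake_loss = 0.0
    for score_fake, score_real in zip(scores_fake, scores_real):
        # Fresh names for the per-scale terms keep the accumulators intact.
        total_, real_, fake_ = loss_func(score_fake=score_fake, score_real=score_real)
        loss, real_loss, fake_loss = loss + total_, real_loss + real_, fake_loss + fake_
    n = len(scores_fake)
    return loss / n, real_loss / n, fake_loss / n

scores_fake = [torch.randn(4, 1) for _ in range(3)]
scores_real = [torch.randn(4, 1) for _ in range(3)]
print(apply_d_loss(scores_fake, scores_real, hinge_d_loss))
```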
@ -1,5 +1,6 @@
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.utils import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
|
|
||||||
class ResidualStack(nn.Module):
|
class ResidualStack(nn.Module):
|
||||||
|
@ -27,7 +28,7 @@ class ResidualStack(nn.Module):
|
||||||
]
|
]
|
||||||
|
|
||||||
self.shortcuts = nn.ModuleList(
|
self.shortcuts = nn.ModuleList(
|
||||||
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)]
|
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -37,6 +38,6 @@ class ResidualStack(nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for block, shortcut in zip(self.blocks, self.shortcuts):
|
for block, shortcut in zip(self.blocks, self.shortcuts):
|
||||||
nn.utils.remove_weight_norm(block[2])
|
remove_parametrizations(block[2], "weight")
|
||||||
nn.utils.remove_weight_norm(block[4])
|
remove_parametrizations(block[4], "weight")
|
||||||
nn.utils.remove_weight_norm(shortcut)
|
remove_parametrizations(shortcut, "weight")
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.utils import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
|
|
||||||
class Conv1d(nn.Conv1d):
|
class Conv1d(nn.Conv1d):
|
||||||
|
@ -56,8 +57,8 @@ class FiLM(nn.Module):
|
||||||
return shift, scale
|
return shift, scale
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.input_conv)
|
remove_parametrizations(self.input_conv, "weight")
|
||||||
nn.utils.remove_weight_norm(self.output_conv)
|
remove_parametrizations(self.output_conv, "weight")
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
self.input_conv = weight_norm(self.input_conv)
|
self.input_conv = weight_norm(self.input_conv)
|
||||||
|
@ -111,13 +112,13 @@ class UBlock(nn.Module):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.res_block)
|
remove_parametrizations(self.res_block, "weight")
|
||||||
for _, layer in enumerate(self.main_block):
|
for _, layer in enumerate(self.main_block):
|
||||||
if len(layer.state_dict()) != 0:
|
if len(layer.state_dict()) != 0:
|
||||||
nn.utils.remove_weight_norm(layer)
|
remove_parametrizations(layer, "weight")
|
||||||
for _, layer in enumerate(self.out_block):
|
for _, layer in enumerate(self.out_block):
|
||||||
if len(layer.state_dict()) != 0:
|
if len(layer.state_dict()) != 0:
|
||||||
nn.utils.remove_weight_norm(layer)
|
remove_parametrizations(layer, "weight")
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
self.res_block = weight_norm(self.res_block)
|
self.res_block = weight_norm(self.res_block)
|
||||||
|
@ -153,10 +154,10 @@ class DBlock(nn.Module):
|
||||||
return o + res
|
return o + res
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.res_block)
|
remove_parametrizations(self.res_block, "weight")
|
||||||
for _, layer in enumerate(self.main_block):
|
for _, layer in enumerate(self.main_block):
|
||||||
if len(layer.state_dict()) != 0:
|
if len(layer.state_dict()) != 0:
|
||||||
nn.utils.remove_weight_norm(layer)
|
remove_parametrizations(layer, "weight")
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
self.res_block = weight_norm(self.res_block)
|
self.res_block = weight_norm(self.res_block)
|
||||||
|
|
|
@ -30,7 +30,7 @@ class DiscriminatorP(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.period = period
|
self.period = period
|
||||||
get_padding = lambda k, d: int((k * d - d) / 2)
|
get_padding = lambda k, d: int((k * d - d) / 2)
|
||||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
|
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
|
||||||
self.convs = nn.ModuleList(
|
self.convs = nn.ModuleList(
|
||||||
[
|
[
|
||||||
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||||
|
@ -125,7 +125,7 @@ class DiscriminatorS(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, use_spectral_norm=False):
|
def __init__(self, use_spectral_norm=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
|
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
|
||||||
self.convs = nn.ModuleList(
|
self.convs = nn.ModuleList(
|
||||||
[
|
[
|
||||||
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
|
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
|
||||||
|
|
|
@ -3,7 +3,8 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import Conv1d, ConvTranspose1d
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
@ -99,9 +100,9 @@ class ResBlock1(torch.nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for l in self.convs1:
|
for l in self.convs1:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
for l in self.convs2:
|
for l in self.convs2:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
|
|
||||||
|
|
||||||
class ResBlock2(torch.nn.Module):
|
class ResBlock2(torch.nn.Module):
|
||||||
|
@ -155,7 +156,7 @@ class ResBlock2(torch.nn.Module):
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for l in self.convs:
|
for l in self.convs:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
|
|
||||||
|
|
||||||
class HifiganGenerator(torch.nn.Module):
|
class HifiganGenerator(torch.nn.Module):
|
||||||
|
@ -227,10 +228,10 @@ class HifiganGenerator(torch.nn.Module):
|
||||||
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
||||||
|
|
||||||
if not conv_pre_weight_norm:
|
if not conv_pre_weight_norm:
|
||||||
remove_weight_norm(self.conv_pre)
|
remove_parametrizations(self.conv_pre, "weight")
|
||||||
|
|
||||||
if not conv_post_weight_norm:
|
if not conv_post_weight_norm:
|
||||||
remove_weight_norm(self.conv_post)
|
remove_parametrizations(self.conv_post, "weight")
|
||||||
|
|
||||||
def forward(self, x, g=None):
|
def forward(self, x, g=None):
|
||||||
"""
|
"""
|
||||||
|
@ -283,11 +284,11 @@ class HifiganGenerator(torch.nn.Module):
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
print("Removing weight norm...")
|
print("Removing weight norm...")
|
||||||
for l in self.ups:
|
for l in self.ups:
|
||||||
remove_weight_norm(l)
|
remove_parametrizations(l, "weight")
|
||||||
for l in self.resblocks:
|
for l in self.resblocks:
|
||||||
l.remove_weight_norm()
|
l.remove_weight_norm()
|
||||||
remove_weight_norm(self.conv_pre)
|
remove_parametrizations(self.conv_pre, "weight")
|
||||||
remove_weight_norm(self.conv_post)
|
remove_parametrizations(self.conv_post, "weight")
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self, config, checkpoint_path, eval=False, cache=False
|
self, config, checkpoint_path, eval=False, cache=False
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.utils import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
|
|
||||||
class MelganDiscriminator(nn.Module):
|
class MelganDiscriminator(nn.Module):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.utils import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
from TTS.vocoder.layers.melgan import ResidualStack
|
from TTS.vocoder.layers.melgan import ResidualStack
|
||||||
|
@ -80,7 +80,7 @@ class MelganGenerator(nn.Module):
|
||||||
for _, layer in enumerate(self.layers):
|
for _, layer in enumerate(self.layers):
|
||||||
if len(layer.state_dict()) != 0:
|
if len(layer.state_dict()) != 0:
|
||||||
try:
|
try:
|
||||||
nn.utils.remove_weight_norm(layer)
|
nn.utils.parametrize.remove_parametrizations(layer, "weight")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
layer.remove_weight_norm()
|
layer.remove_weight_norm()
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
||||||
|
|
||||||
|
@ -68,7 +69,7 @@ class ParallelWaveganDiscriminator(nn.Module):
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
def _apply_weight_norm(m):
|
def _apply_weight_norm(m):
|
||||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||||
torch.nn.utils.weight_norm(m)
|
torch.nn.utils.parametrizations.weight_norm(m)
|
||||||
|
|
||||||
self.apply(_apply_weight_norm)
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
|
@ -76,7 +77,7 @@ class ParallelWaveganDiscriminator(nn.Module):
|
||||||
def _remove_weight_norm(m):
|
def _remove_weight_norm(m):
|
||||||
try:
|
try:
|
||||||
# print(f"Weight norm is removed from {m}.")
|
# print(f"Weight norm is removed from {m}.")
|
||||||
nn.utils.remove_weight_norm(m)
|
remove_parametrizations(m, "weight")
|
||||||
except ValueError: # this module didn't have weight norm
|
except ValueError: # this module didn't have weight norm
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -171,7 +172,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
def _apply_weight_norm(m):
|
def _apply_weight_norm(m):
|
||||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||||
torch.nn.utils.weight_norm(m)
|
torch.nn.utils.parametrizations.weight_norm(m)
|
||||||
|
|
||||||
self.apply(_apply_weight_norm)
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
|
@ -179,7 +180,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
|
||||||
def _remove_weight_norm(m):
|
def _remove_weight_norm(m):
|
||||||
try:
|
try:
|
||||||
print(f"Weight norm is removed from {m}.")
|
print(f"Weight norm is removed from {m}.")
|
||||||
nn.utils.remove_weight_norm(m)
|
remove_parametrizations(m, "weight")
|
||||||
except ValueError: # this module didn't have weight norm
|
except ValueError: # this module didn't have weight norm
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ import math
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
||||||
|
@ -126,7 +127,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
||||||
def _remove_weight_norm(m):
|
def _remove_weight_norm(m):
|
||||||
try:
|
try:
|
||||||
# print(f"Weight norm is removed from {m}.")
|
# print(f"Weight norm is removed from {m}.")
|
||||||
torch.nn.utils.remove_weight_norm(m)
|
remove_parametrizations(m, "weight")
|
||||||
except ValueError: # this module didn't have weight norm
|
except ValueError: # this module didn't have weight norm
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -135,7 +136,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
def _apply_weight_norm(m):
|
def _apply_weight_norm(m):
|
||||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||||
torch.nn.utils.weight_norm(m)
|
torch.nn.utils.parametrizations.weight_norm(m)
|
||||||
# print(f"Weight norm is applied to {m}.")
|
# print(f"Weight norm is applied to {m}.")
|
||||||
|
|
||||||
self.apply(_apply_weight_norm)
|
self.apply(_apply_weight_norm)
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.utils import spectral_norm, weight_norm
|
from torch.nn.utils import spectral_norm
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||||
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
||||||
|
|
|
@ -3,6 +3,7 @@ from typing import List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils import parametrize
|
||||||
|
|
||||||
from TTS.vocoder.layers.lvc_block import LVCBlock
|
from TTS.vocoder.layers.lvc_block import LVCBlock
|
||||||
|
|
||||||
|
@ -113,7 +114,7 @@ class UnivnetGenerator(torch.nn.Module):
|
||||||
def _remove_weight_norm(m):
|
def _remove_weight_norm(m):
|
||||||
try:
|
try:
|
||||||
# print(f"Weight norm is removed from {m}.")
|
# print(f"Weight norm is removed from {m}.")
|
||||||
torch.nn.utils.remove_weight_norm(m)
|
parametrize.remove_parametrizations(m, "weight")
|
||||||
except ValueError: # this module didn't have weight norm
|
except ValueError: # this module didn't have weight norm
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -124,7 +125,7 @@ class UnivnetGenerator(torch.nn.Module):
|
||||||
|
|
||||||
def _apply_weight_norm(m):
|
def _apply_weight_norm(m):
|
||||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||||
torch.nn.utils.weight_norm(m)
|
torch.nn.utils.parametrizations.weight_norm(m)
|
||||||
# print(f"Weight norm is applied to {m}.")
|
# print(f"Weight norm is applied to {m}.")
|
||||||
|
|
||||||
self.apply(_apply_weight_norm)
|
self.apply(_apply_weight_norm)
|
||||||
|
|
|
@ -5,7 +5,8 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.utils import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||||
|
@ -178,27 +179,27 @@ class Wavegrad(BaseVocoder):
|
||||||
for _, layer in enumerate(self.dblocks):
|
for _, layer in enumerate(self.dblocks):
|
||||||
if len(layer.state_dict()) != 0:
|
if len(layer.state_dict()) != 0:
|
||||||
try:
|
try:
|
||||||
nn.utils.remove_weight_norm(layer)
|
remove_parametrizations(layer, "weight")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
layer.remove_weight_norm()
|
layer.remove_weight_norm()
|
||||||
|
|
||||||
for _, layer in enumerate(self.film):
|
for _, layer in enumerate(self.film):
|
||||||
if len(layer.state_dict()) != 0:
|
if len(layer.state_dict()) != 0:
|
||||||
try:
|
try:
|
||||||
nn.utils.remove_weight_norm(layer)
|
remove_parametrizations(layer, "weight")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
layer.remove_weight_norm()
|
layer.remove_weight_norm()
|
||||||
|
|
||||||
for _, layer in enumerate(self.ublocks):
|
for _, layer in enumerate(self.ublocks):
|
||||||
if len(layer.state_dict()) != 0:
|
if len(layer.state_dict()) != 0:
|
||||||
try:
|
try:
|
||||||
nn.utils.remove_weight_norm(layer)
|
remove_parametrizations(layer, "weight")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
layer.remove_weight_norm()
|
layer.remove_weight_norm()
|
||||||
|
|
||||||
nn.utils.remove_weight_norm(self.x_conv)
|
remove_parametrizations(self.x_conv, "weight")
|
||||||
nn.utils.remove_weight_norm(self.out_conv)
|
remove_parametrizations(self.out_conv, "weight")
|
||||||
nn.utils.remove_weight_norm(self.y_conv)
|
remove_parametrizations(self.y_conv, "weight")
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
for _, layer in enumerate(self.dblocks):
|
for _, layer in enumerate(self.dblocks):
|
||||||
|
|
|
@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from TTS.tts.utils.visual import plot_spectrogram
|
from TTS.tts.utils.visual import plot_spectrogram
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.audio.numpy_transforms import mulaw_decode
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||||
from TTS.vocoder.layers.losses import WaveRNNLoss
|
from TTS.vocoder.layers.losses import WaveRNNLoss
|
||||||
|
@ -399,7 +400,7 @@ class Wavernn(BaseVocoder):
|
||||||
output = output[0]
|
output = output[0]
|
||||||
|
|
||||||
if self.args.mulaw and isinstance(self.args.mode, int):
|
if self.args.mulaw and isinstance(self.args.mode, int):
|
||||||
output = AudioProcessor.mulaw_decode(output, self.args.mode)
|
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
|
||||||
|
|
||||||
# Fade-out at the end to avoid signal cutting out suddenly
|
# Fade-out at the end to avoid signal cutting out suddenly
|
||||||
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
|
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
|
||||||
|
|
|
@ -124,7 +124,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
print(TTS().list_models())
|
print(TTS().list_models())
|
||||||
|
|
||||||
# Init TTS
|
# Init TTS
|
||||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1").to(device)
|
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
|
||||||
|
|
||||||
# Run TTS
|
# Run TTS
|
||||||
# ❗ Since this model is multi-lingual voice cloning model, we must set the target speaker_wav and language
|
# ❗ Since this model is multi-lingual voice cloning model, we must set the target speaker_wav and language
|
||||||
|
@ -198,19 +198,12 @@ from TTS.api import CS_API
|
||||||
# Init 🐸 Coqui Studio API
|
# Init 🐸 Coqui Studio API
|
||||||
# you can either set the API token as an environment variable `COQUI_STUDIO_TOKEN` or pass it as an argument.
|
# you can either set the API token as an environment variable `COQUI_STUDIO_TOKEN` or pass it as an argument.
|
||||||
|
|
||||||
# XTTS - Best quality and life-like speech in EN
|
# XTTS - Best quality and life-like speech in multiple languages. See https://docs.coqui.ai/reference/samples_xtts_create for supported languages.
|
||||||
api = CS_API(api_token=<token>, model="XTTS")
|
api = CS_API(api_token=<token>, model="XTTS")
|
||||||
api.speakers # all the speakers are available with all the models.
|
api.speakers # all the speakers are available with all the models.
|
||||||
api.list_speakers()
|
api.list_speakers()
|
||||||
api.list_voices()
|
api.list_voices()
|
||||||
wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", speed=1.5)
|
wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", language="en", speed=1.5)
|
||||||
|
|
||||||
# XTTS-multilingual - Multilingual XTTS with [en, de, es, fr, it, pt, ...] (more langs coming soon)
|
|
||||||
api = CS_API(api_token=<token>, model="XTTS-multilingual")
|
|
||||||
api.speakers
|
|
||||||
api.list_speakers()
|
|
||||||
api.list_voices()
|
|
||||||
wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", speed=1.5)
|
|
||||||
|
|
||||||
# V1 - Fast and lightweight TTS in EN with emotion control.
|
# V1 - Fast and lightweight TTS in EN with emotion control.
|
||||||
api = CS_API(api_token=<token>, model="V1")
|
api = CS_API(api_token=<token>, model="V1")
|
||||||
|
|
|
@ -7,17 +7,24 @@ This is the same model that powers [Coqui Studio](https://coqui.ai/), and [Coqui
|
||||||
a few tricks to make it faster and support streaming inference.
|
a few tricks to make it faster and support streaming inference.
|
||||||
|
|
||||||
### Features
|
### Features
|
||||||
- Voice cloning with just a 3-second audio clip.
|
- Voice cloning.
|
||||||
- Cross-language voice cloning.
|
- Cross-language voice cloning.
|
||||||
- Multi-lingual speech generation.
|
- Multi-lingual speech generation.
|
||||||
- 24 kHz sampling rate.
- 24 kHz sampling rate.
|
||||||
|
- Streaming inference with < 200ms latency. (See [Streaming inference](#streaming-inference))
|
||||||
|
- Fine-tuning support. (See [Training](#training))
|
||||||
|
|
||||||
|
### Updates with v2
|
||||||
|
- Improved voice cloning.
|
||||||
|
- Voices can be cloned with a single audio file or multiple audio files, without any effect on the runtime.
|
||||||
|
- 2 new languages: Hungarian and Korean.
|
||||||
|
- Across the board quality improvements.
|
||||||
|
|
||||||
### Code
|
### Code
|
||||||
Current implementation only supports inference.
|
Current implementation only supports inference.
|
||||||
|
|
||||||
### Languages
|
### Languages
|
||||||
As of now, XTTS-v1.1 supports 14 languages: English, Spanish, French, German, Italian, Portuguese,
|
As of now, XTTS-v2 supports 16 languages: English (en), Spanish (es), French (fr), German (de), Italian (it), Portuguese (pt), Polish (pl), Turkish (tr), Russian (ru), Dutch (nl), Czech (cs), Arabic (ar), Chinese (zh-cn), Japanese (ja), Hungarian (hu) and Korean (ko).
|
||||||
Polish, Turkish, Russian, Dutch, Czech, Arabic, Chinese (Simplified) and Japanese.
|
|
||||||
|
|
||||||
Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
|
Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
|
||||||
|
|
||||||
|
@ -31,27 +38,60 @@ You can also mail us at info@coqui.ai.
|
||||||
### Inference
|
### Inference
|
||||||
#### 🐸TTS API
|
#### 🐸TTS API
|
||||||
|
|
||||||
|
##### Single reference
|
||||||
```python
|
```python
|
||||||
from TTS.api import TTS
|
from TTS.api import TTS
|
||||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1", gpu=True)
|
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
|
||||||
|
|
||||||
# generate speech by cloning a voice using default settings
|
# generate speech by cloning a voice using default settings
|
||||||
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||||
file_path="output.wav",
|
file_path="output.wav",
|
||||||
speaker_wav="/path/to/target/speaker.wav",
|
speaker_wav=["/path/to/target/speaker.wav"],
|
||||||
|
language="en")
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Multiple references
|
||||||
|
```python
|
||||||
|
from TTS.api import TTS
|
||||||
|
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
|
||||||
|
|
||||||
|
# generate speech by cloning a voice using default settings
|
||||||
|
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||||
|
file_path="output.wav",
|
||||||
|
speaker_wav=["/path/to/target/speaker.wav", "/path/to/target/speaker_2.wav", "/path/to/target/speaker_3.wav"],
|
||||||
language="en")
|
language="en")
|
||||||
```
|
```
|
||||||
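If you prefer to keep the audio in memory instead of writing a file, a hedged sketch is shown below; `tts.tts(...)` returning the raw waveform and the `tts.synthesizer.output_sample_rate` attribute are assumptions here, not taken from the examples above.

```python
# Sketch: synthesize to an in-memory waveform with the model loaded above.
import soundfile as sf

wav = tts.tts(
    text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    speaker_wav=["/path/to/target/speaker.wav"],
    language="en",
)
sf.write("output.wav", wav, tts.synthesizer.output_sample_rate)
```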
|
|
||||||
#### 🐸TTS Command line
|
#### 🐸TTS Command line
|
||||||
|
|
||||||
|
##### Single reference
|
||||||
```console
|
```console
|
||||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 \
|
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
|
||||||
--text "Bugün okula gitmek istemiyorum." \
|
--text "Bugün okula gitmek istemiyorum." \
|
||||||
--speaker_wav /path/to/target/speaker.wav \
|
--speaker_wav /path/to/target/speaker.wav \
|
||||||
--language_idx tr \
|
--language_idx tr \
|
||||||
--use_cuda true
|
--use_cuda true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
##### Multiple references
|
||||||
|
```console
|
||||||
|
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
|
||||||
|
--text "Bugün okula gitmek istemiyorum." \
|
||||||
|
--speaker_wav /path/to/target/speaker.wav /path/to/target/speaker_2.wav /path/to/target/speaker_3.wav \
|
||||||
|
--language_idx tr \
|
||||||
|
--use_cuda true
|
||||||
|
```
|
||||||
|
Or, for all wav files in a directory, you can use:
|
||||||
|
|
||||||
|
```console
|
||||||
|
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
|
||||||
|
--text "Bugün okula gitmek istemiyorum." \
|
||||||
|
--speaker_wav /path/to/target/*.wav \
|
||||||
|
--language_idx tr \
|
||||||
|
--use_cuda true
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
#### model directly
|
#### model directly
|
||||||
|
|
||||||
If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
|
If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
|
||||||
|
@ -75,7 +115,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
||||||
print("Computing speaker latents...")
|
print("Computing speaker latents...")
|
||||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
|
||||||
|
|
||||||
print("Inference...")
|
print("Inference...")
|
||||||
out = model.inference(
|
out = model.inference(
|
||||||
|
@ -83,7 +123,6 @@ out = model.inference(
|
||||||
"en",
|
"en",
|
||||||
gpt_cond_latent,
|
gpt_cond_latent,
|
||||||
speaker_embedding,
|
speaker_embedding,
|
||||||
diffusion_conditioning,
|
|
||||||
temperature=0.7, # Add custom parameters here
|
temperature=0.7, # Add custom parameters here
|
||||||
)
|
)
|
||||||
torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
||||||
|
@ -112,7 +151,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
||||||
print("Computing speaker latents...")
|
print("Computing speaker latents...")
|
||||||
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
|
||||||
|
|
||||||
print("Inference...")
|
print("Inference...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
@ -136,7 +175,7 @@ torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
|
||||||
|
|
||||||
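The chunk-consuming part of the streaming example is elided by the diff above; a hedged sketch of how such a generator is typically consumed is shown below. The `model.inference_stream(...)` call and its argument order are assumptions here, not taken from the diff.

```python
# Sketch: stream audio chunks from XTTS and concatenate them into one waveform.
chunks = model.inference_stream(
    "It took me quite a long time to develop a voice.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
)

wav_chunks = []
for i, chunk in enumerate(chunks):
    if i == 0:
        print(f"Time to first chunk: {time.time() - t0}")
    print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
    wav_chunks.append(chunk)
wav = torch.cat(wav_chunks, dim=0)
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```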
### Training
|
### Training
|
||||||
|
|
||||||
A recipe for `XTTS_v1.1` GPT encoder training using `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
|
A recipe for `XTTS_v2` GPT encoder training using `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
|
||||||
|
|
||||||
You need to change the fields of the `BaseDatasetConfig` to match your dataset and then update `GPTArgs` and `GPTTrainerConfig` fields as you need. By default, it will use the same parameters that XTTS v1.1 model was trained with. To speed up the model convergence, as default, it will also download the XTTS v1.1 checkpoint and load it.
|
You need to change the fields of the `BaseDatasetConfig` to match your dataset and then update `GPTArgs` and `GPTTrainerConfig` fields as you need. By default, it will use the same parameters that XTTS v1.1 model was trained with. To speed up the model convergence, as default, it will also download the XTTS v1.1 checkpoint and load it.
|
||||||
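For example, a minimal sketch of the dataset part of such a recipe could look like the snippet below; the paths and names are placeholders, and the remaining `GPTArgs`/`GPTTrainerConfig` fields are left at the recipe defaults.

```python
from TTS.config.shared_configs import BaseDatasetConfig

# Placeholder dataset definition: point these fields at your own data.
config_dataset = BaseDatasetConfig(
    formatter="ljspeech",        # any formatter supported by load_tts_samples
    dataset_name="my_dataset",
    path="/path/to/my_dataset/",
    meta_file_train="metadata.csv",
    language="en",
)
DATASETS_CONFIG_LIST = [config_dataset]
```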
|
|
||||||
|
@ -152,7 +191,7 @@ from TTS.tts.models.xtts import Xtts
|
||||||
# Add here the xtts_config path
|
# Add here the xtts_config path
|
||||||
CONFIG_PATH = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT-October-23-2023_10+36AM-653f2e75/config.json"
|
CONFIG_PATH = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT-October-23-2023_10+36AM-653f2e75/config.json"
|
||||||
# Add here the vocab file that you have used to train the model
|
# Add here the vocab file that you have used to train the model
|
||||||
TOKENIZER_PATH = "recipes/ljspeech/xtts_v1/run/training/XTTS_v1.1_original_model_files/vocab.json"
|
TOKENIZER_PATH = "recipes/ljspeech/xtts_v1/run/training/XTTS_v2_original_model_files/vocab.json"
|
||||||
# Add here the checkpoint that you want to do inference with
|
# Add here the checkpoint that you want to do inference with
|
||||||
XTTS_CHECKPOINT = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT/best_model.pth"
|
XTTS_CHECKPOINT = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT/best_model.pth"
|
||||||
# Add here the speaker reference
|
# Add here the speaker reference
|
||||||
|
@ -169,7 +208,7 @@ model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENI
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
||||||
print("Computing speaker latents...")
|
print("Computing speaker latents...")
|
||||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_REFERENCE)
|
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE])
|
||||||
|
|
||||||
print("Inference...")
|
print("Inference...")
|
||||||
out = model.inference(
|
out = model.inference(
|
||||||
|
@ -177,20 +216,20 @@ out = model.inference(
|
||||||
"en",
|
"en",
|
||||||
gpt_cond_latent,
|
gpt_cond_latent,
|
||||||
speaker_embedding,
|
speaker_embedding,
|
||||||
diffusion_conditioning,
|
|
||||||
temperature=0.7, # Add custom parameters here
|
temperature=0.7, # Add custom parameters here
|
||||||
)
|
)
|
||||||
torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Important resources & papers
|
## References and Acknowledgements
|
||||||
- VallE: https://arxiv.org/abs/2301.02111
|
- VallE: https://arxiv.org/abs/2301.02111
|
||||||
- Tortoise Repo: https://github.com/neonbjb/tortoise-tts
|
- Tortoise Repo: https://github.com/neonbjb/tortoise-tts
|
||||||
- Faster implementation: https://github.com/152334H/tortoise-tts-fast
|
- Faster implementation: https://github.com/152334H/tortoise-tts-fast
|
||||||
- Univnet: https://arxiv.org/abs/2106.07889
|
- Univnet: https://arxiv.org/abs/2106.07889
|
||||||
- Latent Diffusion:https://arxiv.org/abs/2112.10752
|
- Latent Diffusion:https://arxiv.org/abs/2112.10752
|
||||||
- DALL-E: https://arxiv.org/abs/2102.12092
|
- DALL-E: https://arxiv.org/abs/2102.12092
|
||||||
|
- Perceiver: https://arxiv.org/abs/2103.03206
|
||||||
|
|
||||||
|
|
||||||
## XttsConfig
|
## XttsConfig
|
||||||
|
|
|
@ -13,23 +13,28 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
|
||||||
"import sys\n",
|
|
||||||
"import torch\n",
|
|
||||||
"import importlib\n",
|
"import importlib\n",
|
||||||
"import numpy as np\n",
|
"import os\n",
|
||||||
"from tqdm import tqdm\n",
|
|
||||||
"from torch.utils.data import DataLoader\n",
|
|
||||||
"import soundfile as sf\n",
|
|
||||||
"import pickle\n",
|
"import pickle\n",
|
||||||
|
"\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import soundfile as sf\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from matplotlib import pylab as plt\n",
|
||||||
|
"from torch.utils.data import DataLoader\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"\n",
|
||||||
|
"from TTS.config import load_config\n",
|
||||||
|
"from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
|
||||||
|
"from TTS.tts.datasets import load_tts_samples\n",
|
||||||
"from TTS.tts.datasets.dataset import TTSDataset\n",
|
"from TTS.tts.datasets.dataset import TTSDataset\n",
|
||||||
"from TTS.tts.layers.losses import L1LossMasked\n",
|
"from TTS.tts.layers.losses import L1LossMasked\n",
|
||||||
"from TTS.utils.audio import AudioProcessor\n",
|
|
||||||
"from TTS.config import load_config\n",
|
|
||||||
"from TTS.tts.utils.visual import plot_spectrogram\n",
|
|
||||||
"from TTS.tts.utils.helpers import sequence_mask\n",
|
|
||||||
"from TTS.tts.models import setup_model\n",
|
"from TTS.tts.models import setup_model\n",
|
||||||
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
|
"from TTS.tts.utils.helpers import sequence_mask\n",
|
||||||
|
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
|
||||||
|
"from TTS.tts.utils.visual import plot_spectrogram\n",
|
||||||
|
"from TTS.utils.audio import AudioProcessor\n",
|
||||||
|
"from TTS.utils.audio.numpy_transforms import quantize\n",
|
||||||
"\n",
|
"\n",
|
||||||
"%matplotlib inline\n",
|
"%matplotlib inline\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -49,11 +54,9 @@
|
||||||
" file_name = wav_file.split('.')[0]\n",
|
" file_name = wav_file.split('.')[0]\n",
|
||||||
" os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
|
" os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
|
||||||
" os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
|
" os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
|
||||||
" os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
|
|
||||||
" wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
|
" wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
|
||||||
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
|
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
|
||||||
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
|
" return file_name, wavq_path, mel_path"
|
||||||
" return file_name, wavq_path, mel_path, wav_path"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -65,14 +68,14 @@
|
||||||
"# Paths and configurations\n",
|
"# Paths and configurations\n",
|
||||||
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
|
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
|
||||||
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
|
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
|
||||||
|
"PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
|
||||||
"DATASET = \"ljspeech\"\n",
|
"DATASET = \"ljspeech\"\n",
|
||||||
"METADATA_FILE = \"metadata.csv\"\n",
|
"METADATA_FILE = \"metadata.csv\"\n",
|
||||||
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
|
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
|
||||||
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
|
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
|
||||||
"BATCH_SIZE = 32\n",
|
"BATCH_SIZE = 32\n",
|
||||||
"\n",
|
"\n",
|
||||||
"QUANTIZED_WAV = False\n",
|
"QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n",
|
||||||
"QUANTIZE_BIT = None\n",
|
|
||||||
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Check CUDA availability\n",
|
"# Check CUDA availability\n",
|
||||||
|
@ -80,10 +83,10 @@
|
||||||
"print(\" > CUDA enabled: \", use_cuda)\n",
|
"print(\" > CUDA enabled: \", use_cuda)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Load the configuration\n",
|
"# Load the configuration\n",
|
||||||
|
"dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
|
||||||
"C = load_config(CONFIG_PATH)\n",
|
"C = load_config(CONFIG_PATH)\n",
|
||||||
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
|
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
|
||||||
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
|
"ap = AudioProcessor(**C.audio)"
|
||||||
"print(C['r'])"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -92,12 +95,10 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# If the vocabulary was passed, replace the default\n",
|
"# Initialize the tokenizer\n",
|
||||||
"if 'characters' in C and C['characters']:\n",
|
"tokenizer, C = TTSTokenizer.init_from_config(C)\n",
|
||||||
" symbols, phonemes = make_symbols(**C.characters)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Load the model\n",
|
"# Load the model\n",
|
||||||
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
|
|
||||||
"# TODO: multiple speakers\n",
|
"# TODO: multiple speakers\n",
|
||||||
"model = setup_model(C)\n",
|
"model = setup_model(C)\n",
|
||||||
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
||||||
|
@ -109,42 +110,21 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Load the preprocessor based on the dataset\n",
|
"# Load data instances\n",
|
||||||
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
|
"meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
|
||||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
"meta_data = meta_data_train + meta_data_eval\n",
|
||||||
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
|
"\n",
|
||||||
"dataset = TTSDataset(\n",
|
"dataset = TTSDataset(\n",
|
||||||
" C,\n",
|
" outputs_per_step=C[\"r\"],\n",
|
||||||
" C.text_cleaner,\n",
|
" compute_linear_spec=False,\n",
|
||||||
" False,\n",
|
" ap=ap,\n",
|
||||||
" ap,\n",
|
" samples=meta_data,\n",
|
||||||
" meta_data,\n",
|
" tokenizer=tokenizer,\n",
|
||||||
" characters=C.get('characters', None),\n",
|
" phoneme_cache_path=PHONEME_CACHE_PATH,\n",
|
||||||
" use_phonemes=C.use_phonemes,\n",
|
|
||||||
" phoneme_cache_path=C.phoneme_cache_path,\n",
|
|
||||||
" enable_eos_bos=C.enable_eos_bos_chars,\n",
|
|
||||||
")\n",
|
")\n",
|
||||||
"loader = DataLoader(\n",
|
"loader = DataLoader(\n",
|
||||||
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
|
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
|
||||||
")\n"
|
")"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Initialize lists for storing results\n",
|
|
||||||
"file_idxs = []\n",
|
|
||||||
"metadata = []\n",
|
|
||||||
"losses = []\n",
|
|
||||||
"postnet_losses = []\n",
|
|
||||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
|
||||||
"\n",
|
|
||||||
"# Create log file\n",
|
|
||||||
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
|
||||||
"log_file = open(log_file_path, \"w\")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -160,26 +140,33 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# Initialize lists for storing results\n",
|
||||||
|
"file_idxs = []\n",
|
||||||
|
"metadata = []\n",
|
||||||
|
"losses = []\n",
|
||||||
|
"postnet_losses = []\n",
|
||||||
|
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||||
|
"\n",
|
||||||
"# Start processing with a progress bar\n",
|
"# Start processing with a progress bar\n",
|
||||||
"with torch.no_grad():\n",
|
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
||||||
|
"with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
|
||||||
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" # setup input data\n",
|
|
||||||
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
|
|
||||||
"\n",
|
|
||||||
" # dispatch data to GPU\n",
|
" # dispatch data to GPU\n",
|
||||||
" if use_cuda:\n",
|
" if use_cuda:\n",
|
||||||
" text_input = text_input.cuda()\n",
|
" data[\"token_id\"] = data[\"token_id\"].cuda()\n",
|
||||||
" text_lengths = text_lengths.cuda()\n",
|
" data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
|
||||||
" mel_input = mel_input.cuda()\n",
|
" data[\"mel\"] = data[\"mel\"].cuda()\n",
|
||||||
" mel_lengths = mel_lengths.cuda()\n",
|
" data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" mask = sequence_mask(text_lengths)\n",
|
" mask = sequence_mask(data[\"token_id_lengths\"])\n",
|
||||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
" outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
|
||||||
|
" mel_outputs = outputs[\"decoder_outputs\"]\n",
|
||||||
|
" postnet_outputs = outputs[\"model_outputs\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # compute loss\n",
|
" # compute loss\n",
|
||||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
" loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
|
||||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
" loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
|
||||||
" losses.append(loss.item())\n",
|
" losses.append(loss.item())\n",
|
||||||
" postnet_losses.append(loss_postnet.item())\n",
|
" postnet_losses.append(loss_postnet.item())\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -193,28 +180,27 @@
|
||||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||||
" elif C.model == \"Tacotron2\":\n",
|
" elif C.model == \"Tacotron2\":\n",
|
||||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||||
" alignments = alignments.detach().cpu().numpy()\n",
|
" alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if not DRY_RUN:\n",
|
" if not DRY_RUN:\n",
|
||||||
" for idx in range(text_input.shape[0]):\n",
|
" for idx in range(data[\"token_id\"].shape[0]):\n",
|
||||||
" wav_file_path = item_idx[idx]\n",
|
" wav_file_path = data[\"item_idxs\"][idx]\n",
|
||||||
" wav = ap.load_wav(wav_file_path)\n",
|
" wav = ap.load_wav(wav_file_path)\n",
|
||||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
" file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||||
" file_idxs.append(file_name)\n",
|
" file_idxs.append(file_name)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # quantize and save wav\n",
|
" # quantize and save wav\n",
|
||||||
" if QUANTIZED_WAV:\n",
|
" if QUANTIZE_BITS > 0:\n",
|
||||||
" wavq = ap.quantize(wav)\n",
|
" wavq = quantize(wav, QUANTIZE_BITS)\n",
|
||||||
" np.save(wavq_path, wavq)\n",
|
" np.save(wavq_path, wavq)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # save TTS mel\n",
|
" # save TTS mel\n",
|
||||||
" mel = postnet_outputs[idx]\n",
|
" mel = postnet_outputs[idx]\n",
|
||||||
" mel_length = mel_lengths[idx]\n",
|
" mel_length = data[\"mel_lengths\"][idx]\n",
|
||||||
" mel = mel[:mel_length, :].T\n",
|
" mel = mel[:mel_length, :].T\n",
|
||||||
" np.save(mel_path, mel)\n",
|
" np.save(mel_path, mel)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" metadata.append([wav_file_path, mel_path])\n",
|
" metadata.append([wav_file_path, mel_path])\n",
|
||||||
"\n",
|
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -224,35 +210,20 @@
|
||||||
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
||||||
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Close the log file\n",
|
|
||||||
"log_file.close()\n",
|
|
||||||
"\n",
|
|
||||||
"# For wavernn\n",
|
"# For wavernn\n",
|
||||||
"if not DRY_RUN:\n",
|
"if not DRY_RUN:\n",
|
||||||
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# For pwgan\n",
|
"# For pwgan\n",
|
||||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||||
" for data in metadata:\n",
|
" for wav_file_path, mel_path in metadata:\n",
|
||||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
" f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print mean losses\n",
|
"# Print mean losses\n",
|
||||||
"print(f\"Mean Loss: {mean_loss}\")\n",
|
"print(f\"Mean Loss: {mean_loss}\")\n",
|
||||||
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# for pwgan\n",
|
|
||||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
|
||||||
" for data in metadata:\n",
|
|
||||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -267,7 +238,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"idx = 1\n",
|
"idx = 1\n",
|
||||||
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
|
"ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -276,10 +247,9 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import soundfile as sf\n",
|
"wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
|
||||||
"wav, sr = sf.read(item_idx[idx])\n",
|
"mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
|
||||||
"mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
|
"mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
|
||||||
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
|
|
||||||
"mel_truth = ap.melspectrogram(wav)\n",
|
"mel_truth = ap.melspectrogram(wav)\n",
|
||||||
"print(mel_truth.shape)"
|
"print(mel_truth.shape)"
|
||||||
]
|
]
|
||||||
|
@ -291,7 +261,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# plot posnet output\n",
|
"# plot posnet output\n",
|
||||||
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
|
"print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
|
||||||
"plot_spectrogram(mel_postnet, ap)"
|
"plot_spectrogram(mel_postnet, ap)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -324,10 +294,9 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# postnet, decoder diff\n",
|
"# postnet, decoder diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel_diff = mel_decoder - mel_postnet\n",
|
"mel_diff = mel_decoder - mel_postnet\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
|
@ -339,10 +308,9 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# PLOT GT SPECTROGRAM diff\n",
|
"# PLOT GT SPECTROGRAM diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel_diff2 = mel_truth.T - mel_decoder\n",
|
"mel_diff2 = mel_truth.T - mel_decoder\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
|
@ -354,21 +322,13 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# PLOT GT SPECTROGRAM diff\n",
|
"# PLOT GT SPECTROGRAM diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel = postnet_outputs[idx]\n",
|
"mel = postnet_outputs[idx]\n",
|
||||||
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
|
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
|
@ -7,7 +7,6 @@ from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
|
|
||||||
# Logging parameters
|
# Logging parameters
|
||||||
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
|
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
|
||||||
PROJECT_NAME = "XTTS_trainer"
|
PROJECT_NAME = "XTTS_trainer"
|
||||||
|
@ -42,8 +41,8 @@ os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
# DVAE files
|
# DVAE files
|
||||||
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth"
|
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/dvae.pth"
|
||||||
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
|
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/mel_stats.pth"
|
||||||
|
|
||||||
# Set the path to the downloaded files
|
# Set the path to the downloaded files
|
||||||
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
|
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
|
||||||
|
@ -56,8 +55,8 @@ if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
||||||
|
|
||||||
|
|
||||||
# Download XTTS v1.1 checkpoint if needed
|
# Download XTTS v1.1 checkpoint if needed
|
||||||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"
|
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json"
|
||||||
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
|
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth"
|
||||||
|
|
||||||
# XTTS transfer learning parameters: You need to provide the paths of the XTTS checkpoint that you want to fine-tune.
|
# XTTS transfer learning parameters: You need to provide the paths of the XTTS checkpoint that you want to fine-tune.
|
||||||
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
|
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
|
||||||
|
@ -66,13 +65,15 @@ XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split(
|
||||||
# download XTTS v1.1 files if needed
|
# download XTTS v1.1 files if needed
|
||||||
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
||||||
print(" > Downloading XTTS v1.1 files!")
|
print(" > Downloading XTTS v1.1 files!")
|
||||||
ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
ModelManager._download_model_files(
|
||||||
|
[TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Training sentences generations
|
# Training sentences generations
|
||||||
SPEAKER_REFERENCE = (
|
SPEAKER_REFERENCE = [
|
||||||
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||||
)
|
]
|
||||||
LANGUAGE = config_dataset.language
|
LANGUAGE = config_dataset.language
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,12 +94,9 @@ def main():
|
||||||
gpt_num_audio_tokens=8194,
|
gpt_num_audio_tokens=8194,
|
||||||
gpt_start_audio_token=8192,
|
gpt_start_audio_token=8192,
|
||||||
gpt_stop_audio_token=8193,
|
gpt_stop_audio_token=8193,
|
||||||
use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint
|
|
||||||
)
|
)
|
||||||
# define audio config
|
# define audio config
|
||||||
audio_config = XttsAudioConfig(
|
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
|
||||||
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
|
|
||||||
)
|
|
||||||
# training parameters config
|
# training parameters config
|
||||||
config = GPTTrainerConfig(
|
config = GPTTrainerConfig(
|
||||||
output_path=OUT_PATH,
|
output_path=OUT_PATH,
|
||||||
|
|
|
@ -0,0 +1,176 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||||
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
|
# Logging parameters
|
||||||
|
RUN_NAME = "GPT_XTTS_v2.0_LJSpeech_FT"
|
||||||
|
PROJECT_NAME = "XTTS_trainer"
|
||||||
|
DASHBOARD_LOGGER = "tensorboard"
|
||||||
|
LOGGER_URI = None
|
||||||
|
|
||||||
|
# Set here the path that the checkpoints will be saved. Default: ./run/training/
|
||||||
|
OUT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run", "training")
|
||||||
|
|
||||||
|
# Training Parameters
|
||||||
|
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
|
||||||
|
START_WITH_EVAL = True # if True it will start with evaluation
|
||||||
|
BATCH_SIZE = 3 # set here the batch size
|
||||||
|
GRAD_ACUMM_STEPS = 84 # set here the grad accumulation steps
|
||||||
|
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
|
||||||
|
|
||||||
|
# Define here the dataset that you want to use for the fine-tuning on.
|
||||||
|
config_dataset = BaseDatasetConfig(
|
||||||
|
formatter="ljspeech",
|
||||||
|
dataset_name="ljspeech",
|
||||||
|
path="/raid/datasets/LJSpeech-1.1_24khz/",
|
||||||
|
meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
|
||||||
|
language="en",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add here the configs of the datasets
|
||||||
|
DATASETS_CONFIG_LIST = [config_dataset]
|
||||||
|
|
||||||
|
# Define the path where XTTS v2.0.1 files will be downloaded
|
||||||
|
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
|
||||||
|
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
# DVAE files
|
||||||
|
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
|
||||||
|
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
|
||||||
|
|
||||||
|
# Set the path to the downloaded files
|
||||||
|
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
|
||||||
|
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))
|
||||||
|
|
||||||
|
# download DVAE files if needed
|
||||||
|
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
||||||
|
print(" > Downloading DVAE files!")
|
||||||
|
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Download XTTS v2.0 checkpoint if needed
|
||||||
|
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
|
||||||
|
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
|
||||||
|
|
||||||
|
# XTTS transfer learning parameters: You need to provide the paths of the XTTS checkpoint that you want to fine-tune.
|
||||||
|
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file
|
||||||
|
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file
|
||||||
|
|
||||||
|
# download XTTS v2.0 files if needed
|
||||||
|
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
||||||
|
print(" > Downloading XTTS v2.0 files!")
|
||||||
|
ModelManager._download_model_files(
|
||||||
|
[TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Training sentences generations
|
||||||
|
SPEAKER_REFERENCE = [
|
||||||
|
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||||
|
]
|
||||||
|
LANGUAGE = config_dataset.language
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# init args and config
|
||||||
|
model_args = GPTArgs(
|
||||||
|
max_conditioning_length=132300, # 6 secs
|
||||||
|
min_conditioning_length=66150, # 3 secs
|
||||||
|
debug_loading_failures=False,
|
||||||
|
max_wav_length=255995, # ~11.6 seconds
|
||||||
|
max_text_length=200,
|
||||||
|
mel_norm_file=MEL_NORM_FILE,
|
||||||
|
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||||
|
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
|
||||||
|
tokenizer_file=TOKENIZER_FILE,
|
||||||
|
gpt_num_audio_tokens=1026,
|
||||||
|
gpt_start_audio_token=1024,
|
||||||
|
gpt_stop_audio_token=1025,
|
||||||
|
gpt_use_masking_gt_prompt_approach=True,
|
||||||
|
gpt_use_perceiver_resampler=True,
|
||||||
|
)
|
||||||
|
# define audio config
|
||||||
|
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
|
||||||
|
# training parameters config
|
||||||
|
config = GPTTrainerConfig(
|
||||||
|
output_path=OUT_PATH,
|
||||||
|
model_args=model_args,
|
||||||
|
run_name=RUN_NAME,
|
||||||
|
project_name=PROJECT_NAME,
|
||||||
|
run_description="""
|
||||||
|
GPT XTTS training
|
||||||
|
""",
|
||||||
|
dashboard_logger=DASHBOARD_LOGGER,
|
||||||
|
logger_uri=LOGGER_URI,
|
||||||
|
audio=audio_config,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
batch_group_size=48,
|
||||||
|
eval_batch_size=BATCH_SIZE,
|
||||||
|
num_loader_workers=8,
|
||||||
|
eval_split_max_size=256,
|
||||||
|
print_step=50,
|
||||||
|
plot_step=100,
|
||||||
|
log_model_step=1000,
|
||||||
|
save_step=10000,
|
||||||
|
save_n_checkpoints=1,
|
||||||
|
save_checkpoints=True,
|
||||||
|
# target_loss="loss",
|
||||||
|
print_eval=False,
|
||||||
|
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
|
||||||
|
optimizer="AdamW",
|
||||||
|
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
|
||||||
|
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
|
||||||
|
lr=5e-06, # learning rate
|
||||||
|
lr_scheduler="MultiStepLR",
|
||||||
|
# it was adjusted accordingly for the new step scheme
|
||||||
|
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
||||||
|
test_sentences=[
|
||||||
|
{
|
||||||
|
"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||||
|
"speaker_wav": SPEAKER_REFERENCE,
|
||||||
|
"language": LANGUAGE,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "This cake is great. It's so delicious and moist.",
|
||||||
|
"speaker_wav": SPEAKER_REFERENCE,
|
||||||
|
"language": LANGUAGE,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# init the model from config
|
||||||
|
model = GPTTrainer.init_from_config(config)
|
||||||
|
|
||||||
|
# load training samples
|
||||||
|
train_samples, eval_samples = load_tts_samples(
|
||||||
|
DATASETS_CONFIG_LIST,
|
||||||
|
eval_split=True,
|
||||||
|
eval_split_max_size=config.eval_split_max_size,
|
||||||
|
eval_split_size=config.eval_split_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# init the trainer and 🚀
|
||||||
|
trainer = Trainer(
|
||||||
|
TrainerArgs(
|
||||||
|
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key, so there is no need to restore it with the Trainer restore_path parameter
|
||||||
|
skip_train_epoch=False,
|
||||||
|
start_with_eval=START_WITH_EVAL,
|
||||||
|
grad_accum_steps=GRAD_ACUMM_STEPS,
|
||||||
|
),
|
||||||
|
config,
|
||||||
|
output_path=OUT_PATH,
|
||||||
|
model=model,
|
||||||
|
train_samples=train_samples,
|
||||||
|
eval_samples=eval_samples,
|
||||||
|
)
|
||||||
|
trainer.fit()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,38 +1,40 @@
|
||||||
# core deps
|
# core deps
|
||||||
numpy==1.22.0;python_version<="3.10"
|
numpy==1.22.0;python_version<="3.10"
|
||||||
numpy==1.24.3;python_version>"3.10"
|
numpy>=1.24.3;python_version>"3.10"
|
||||||
cython==0.29.30
|
cython>=0.29.30
|
||||||
scipy>=1.11.2
|
scipy>=1.11.2
|
||||||
torch>=1.7
|
torch>=2.1
|
||||||
torchaudio
|
torchaudio
|
||||||
soundfile==0.12.*
|
soundfile>=0.12.0
|
||||||
librosa==0.10.*
|
librosa>=0.10.0
|
||||||
scikit-learn==1.3.0
|
scikit-learn>=1.3.0
|
||||||
numba==0.55.1;python_version<"3.9"
|
numba==0.55.1;python_version<"3.9"
|
||||||
numba==0.57.0;python_version>="3.9"
|
numba>=0.57.0;python_version>="3.9"
|
||||||
inflect==5.6.*
|
inflect>=5.6.0
|
||||||
tqdm==4.64.*
|
tqdm>=4.64.1
|
||||||
anyascii==0.3.*
|
anyascii>=0.3.0
|
||||||
pyyaml==6.*
|
pyyaml>=6.0
|
||||||
fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
|
fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
|
||||||
aiohttp==3.8.*
|
aiohttp>=3.8.1
|
||||||
packaging==23.1
|
packaging>=23.1
|
||||||
# deps for examples
|
# deps for examples
|
||||||
flask==2.*
|
flask>=2.0.1
|
||||||
# deps for inference
|
# deps for inference
|
||||||
pysbd==0.3.4
|
pysbd>=0.3.4
|
||||||
# deps for notebooks
|
# deps for notebooks
|
||||||
umap-learn==0.5.*
|
umap-learn>=0.5.1
|
||||||
pandas>=1.4,<2.0
|
pandas>=1.4,<2.0
|
||||||
# deps for training
|
# deps for training
|
||||||
matplotlib==3.7.*
|
matplotlib>=3.7.0
|
||||||
# coqui stack
|
# coqui stack
|
||||||
trainer
|
trainer>=0.0.32
|
||||||
# config management
|
# config management
|
||||||
coqpit>=0.0.16
|
coqpit>=0.0.16
|
||||||
# chinese g2p deps
|
# chinese g2p deps
|
||||||
jieba
|
jieba
|
||||||
pypinyin
|
pypinyin
|
||||||
|
# korean
|
||||||
|
hangul_romanize
|
||||||
# gruut+supported langs
|
# gruut+supported langs
|
||||||
gruut[de,es,fr]==2.2.3
|
gruut[de,es,fr]==2.2.3
|
||||||
# deps for korean
|
# deps for korean
|
||||||
|
@ -44,10 +46,11 @@ bangla
|
||||||
bnnumerizer
|
bnnumerizer
|
||||||
bnunicodenormalizer
|
bnunicodenormalizer
|
||||||
#deps for tortoise
|
#deps for tortoise
|
||||||
k_diffusion
|
einops>=0.6.0
|
||||||
einops==0.6.*
|
transformers>=4.33.0
|
||||||
transformers==4.33.*
|
|
||||||
#deps for bark
|
#deps for bark
|
||||||
encodec==0.1.*
|
encodec>=0.1.1
|
||||||
# deps for XTTS
|
# deps for XTTS
|
||||||
unidecode==1.3.*
|
unidecode>=1.3.2
|
||||||
|
num2words
|
||||||
|
spacy[ja]>=3
|
|
@ -22,7 +22,4 @@ def test_synthesize():
|
||||||
)
|
)
|
||||||
|
|
||||||
# test pipe_out command
|
# test pipe_out command
|
||||||
run_cli(
|
run_cli(f'tts --text "test." --pipe_out --out_path "{output_path}" | aplay')
|
||||||
'tts --text "test." --pipe_out '
|
|
||||||
f'--out_path "{output_path}" | aplay'
|
|
||||||
)
|
|
||||||
|
|
|
@ -3,11 +3,11 @@ import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from trainer.io import save_checkpoint
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||||
from TTS.encoder.utils.io import save_checkpoint
|
|
||||||
from TTS.tts.utils.managers import EmbeddingManager
|
from TTS.tts.utils.managers import EmbeddingManager
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class EmbeddingManagerTest(unittest.TestCase):
|
||||||
|
|
||||||
# create a dummy speaker encoder
|
# create a dummy speaker encoder
|
||||||
model = setup_encoder_model(config)
|
model = setup_encoder_model(config)
|
||||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
|
||||||
|
|
||||||
# load audio processor and speaker encoder
|
# load audio processor and speaker encoder
|
||||||
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
|
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
|
||||||
|
|
|
@ -3,11 +3,11 @@ import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from trainer.io import save_checkpoint
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||||
from TTS.encoder.utils.io import save_checkpoint
|
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ class SpeakerManagerTest(unittest.TestCase):
|
||||||
|
|
||||||
# create a dummy speaker encoder
|
# create a dummy speaker encoder
|
||||||
model = setup_encoder_model(config)
|
model = setup_encoder_model(config)
|
||||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
|
||||||
|
|
||||||
# load audio processor and speaker encoder
|
# load audio processor and speaker encoder
|
||||||
ap = AudioProcessor(**config.audio)
|
ap = AudioProcessor(**config.audio)
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from trainer.io import save_checkpoint
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.utils.io import save_checkpoint
|
|
||||||
from TTS.utils.synthesizer import Synthesizer
|
from TTS.utils.synthesizer import Synthesizer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
||||||
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
||||||
from TTS.config import BaseAudioConfig
|
from TTS.config import BaseAudioConfig
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.audio.numpy_transforms import stft
|
||||||
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
|
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
|
||||||
|
|
||||||
TESTS_PATH = get_tests_path()
|
TESTS_PATH = get_tests_path()
|
||||||
|
@ -21,7 +22,7 @@ def test_torch_stft():
|
||||||
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
|
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
|
||||||
# librosa stft
|
# librosa stft
|
||||||
wav = ap.load_wav(WAV_FILE)
|
wav = ap.load_wav(WAV_FILE)
|
||||||
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
|
M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
|
||||||
# torch stft
|
# torch stft
|
||||||
wav = torch.from_numpy(wav[None, :]).float()
|
wav = torch.from_numpy(wav[None, :]).float()
|
||||||
M_torch = torch_stft(wav)
|
M_torch = torch_stft(wav)
|
||||||
|
|
|
@ -60,7 +60,9 @@ XTTS_CHECKPOINT = None # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s
|
||||||
|
|
||||||
|
|
||||||
# Training sentences generations
|
# Training sentences generations
|
||||||
SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
SPEAKER_REFERENCE = [
|
||||||
|
"tests/data/ljspeech/wavs/LJ001-0002.wav"
|
||||||
|
] # speaker reference to be used in training test sentences
|
||||||
LANGUAGE = config_dataset.language
|
LANGUAGE = config_dataset.language
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,9 +89,7 @@ model_args = GPTArgs(
|
||||||
gpt_start_audio_token=8192,
|
gpt_start_audio_token=8192,
|
||||||
gpt_stop_audio_token=8193,
|
gpt_stop_audio_token=8193,
|
||||||
)
|
)
|
||||||
audio_config = XttsAudioConfig(
|
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
|
||||||
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
|
|
||||||
)
|
|
||||||
config = GPTTrainerConfig(
|
config = GPTTrainerConfig(
|
||||||
epochs=1,
|
epochs=1,
|
||||||
output_path=OUT_PATH,
|
output_path=OUT_PATH,
|
||||||
|
|
|
@ -0,0 +1,163 @@
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
|
from tests import get_tests_output_path
|
||||||
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||||
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||||
|
|
||||||
|
config_dataset = BaseDatasetConfig(
|
||||||
|
formatter="ljspeech",
|
||||||
|
dataset_name="ljspeech",
|
||||||
|
path="tests/data/ljspeech/",
|
||||||
|
meta_file_train="metadata.csv",
|
||||||
|
meta_file_val="metadata.csv",
|
||||||
|
language="en",
|
||||||
|
)
|
||||||
|
|
||||||
|
DATASETS_CONFIG_LIST = [config_dataset]
|
||||||
|
|
||||||
|
# Logging parameters
|
||||||
|
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
|
||||||
|
PROJECT_NAME = "XTTS_trainer"
|
||||||
|
DASHBOARD_LOGGER = "tensorboard"
|
||||||
|
LOGGER_URI = None
|
||||||
|
|
||||||
|
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
|
||||||
|
os.makedirs(OUT_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
# Create DVAE checkpoint and mel_norms on test time
|
||||||
|
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
|
||||||
|
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
|
||||||
|
# Mel spectrogram norms, required for dvae mel spectrogram extraction
|
||||||
|
MEL_NORM_FILE = os.path.join(OUT_PATH, "mel_stats.pth")
|
||||||
|
dvae = DiscreteVAE(
|
||||||
|
channels=80,
|
||||||
|
normalization=None,
|
||||||
|
positional_dims=1,
|
||||||
|
num_tokens=8192,
|
||||||
|
codebook_dim=512,
|
||||||
|
hidden_dim=512,
|
||||||
|
num_resnet_blocks=3,
|
||||||
|
kernel_size=3,
|
||||||
|
num_layers=2,
|
||||||
|
use_transposed_convs=False,
|
||||||
|
)
|
||||||
|
torch.save(dvae.state_dict(), DVAE_CHECKPOINT)
|
||||||
|
mel_stats = torch.ones(80)
|
||||||
|
torch.save(mel_stats, MEL_NORM_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
# XTTS transfer learning parameters: You need to provide the paths of the XTTS checkpoint that you want to fine-tune.
|
||||||
|
TOKENIZER_FILE = "tests/inputs/xtts_vocab.json" # vocab.json file
|
||||||
|
XTTS_CHECKPOINT = None # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth" # model.pth file
|
||||||
|
|
||||||
|
|
||||||
|
# Training sentences generations
|
||||||
|
SPEAKER_REFERENCE = [
|
||||||
|
"tests/data/ljspeech/wavs/LJ001-0002.wav"
|
||||||
|
] # speaker reference to be used in training test sentences
|
||||||
|
LANGUAGE = config_dataset.language
|
||||||
|
|
||||||
|
|
||||||
|
# Training Parameters
|
||||||
|
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
|
||||||
|
START_WITH_EVAL = False # if True it will start with evaluation
|
||||||
|
BATCH_SIZE = 2 # set here the batch size
|
||||||
|
GRAD_ACUMM_STEPS = 1 # set here the grad accumulation steps
|
||||||
|
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
|
||||||
|
|
||||||
|
|
||||||
|
# init args and config
|
||||||
|
model_args = GPTArgs(
|
||||||
|
max_conditioning_length=132300, # 6 secs
|
||||||
|
min_conditioning_length=66150, # 3 secs
|
||||||
|
debug_loading_failures=False,
|
||||||
|
max_wav_length=255995, # ~11.6 seconds
|
||||||
|
max_text_length=200,
|
||||||
|
mel_norm_file=MEL_NORM_FILE,
|
||||||
|
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||||
|
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
|
||||||
|
tokenizer_file=TOKENIZER_FILE,
|
||||||
|
gpt_num_audio_tokens=8194,
|
||||||
|
gpt_start_audio_token=8192,
|
||||||
|
gpt_stop_audio_token=8193,
|
||||||
|
gpt_use_masking_gt_prompt_approach=True,
|
||||||
|
gpt_use_perceiver_resampler=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
|
||||||
|
|
||||||
|
config = GPTTrainerConfig(
|
||||||
|
epochs=1,
|
||||||
|
output_path=OUT_PATH,
|
||||||
|
model_args=model_args,
|
||||||
|
run_name=RUN_NAME,
|
||||||
|
project_name=PROJECT_NAME,
|
||||||
|
run_description="GPT XTTS training",
|
||||||
|
dashboard_logger=DASHBOARD_LOGGER,
|
||||||
|
logger_uri=LOGGER_URI,
|
||||||
|
audio=audio_config,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
batch_group_size=48,
|
||||||
|
eval_batch_size=BATCH_SIZE,
|
||||||
|
num_loader_workers=8,
|
||||||
|
eval_split_max_size=256,
|
||||||
|
print_step=50,
|
||||||
|
plot_step=100,
|
||||||
|
log_model_step=1000,
|
||||||
|
save_step=10000,
|
||||||
|
save_n_checkpoints=1,
|
||||||
|
save_checkpoints=True,
|
||||||
|
# target_loss="loss",
|
||||||
|
print_eval=False,
|
||||||
|
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
|
||||||
|
optimizer="AdamW",
|
||||||
|
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
|
||||||
|
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
|
||||||
|
lr=5e-06, # learning rate
|
||||||
|
lr_scheduler="MultiStepLR",
|
||||||
|
# it was adjusted accordingly for the new step scheme
|
||||||
|
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
||||||
|
test_sentences=[
|
||||||
|
{
|
||||||
|
"text": "This cake is great. It's so delicious and moist.",
|
||||||
|
"speaker_wav": SPEAKER_REFERENCE,
|
||||||
|
"language": LANGUAGE,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# init the model from config
|
||||||
|
model = GPTTrainer.init_from_config(config)
|
||||||
|
|
||||||
|
# load training samples
|
||||||
|
train_samples, eval_samples = load_tts_samples(
|
||||||
|
DATASETS_CONFIG_LIST,
|
||||||
|
eval_split=True,
|
||||||
|
eval_split_max_size=config.eval_split_max_size,
|
||||||
|
eval_split_size=config.eval_split_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# init the trainer and 🚀
|
||||||
|
trainer = Trainer(
|
||||||
|
TrainerArgs(
|
||||||
|
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key, so there is no need to restore it with the Trainer restore_path parameter
|
||||||
|
skip_train_epoch=False,
|
||||||
|
start_with_eval=True,
|
||||||
|
grad_accum_steps=GRAD_ACUMM_STEPS,
|
||||||
|
),
|
||||||
|
config,
|
||||||
|
output_path=OUT_PATH,
|
||||||
|
model=model,
|
||||||
|
train_samples=train_samples,
|
||||||
|
eval_samples=eval_samples,
|
||||||
|
)
|
||||||
|
trainer.fit()
|
||||||
|
|
||||||
|
# remove output path
|
||||||
|
shutil.rmtree(OUT_PATH)
|
|
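# Illustrative sketch (not part of the recipe): if the output directory is kept instead of being
# removed, a fine-tuned checkpoint can usually be loaded back with the same Xtts APIs exercised in
# the tests below. The run folder name is an assumption; use whatever directory the trainer created
# under OUT_PATH, and note that the tokenizer/vocab file may also need to sit next to the checkpoint.
import os

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

checkpoint_dir = os.path.join(OUT_PATH, "GPT_XTTS_FT-run")  # hypothetical run folder
ft_config = XttsConfig()
ft_config.load_json(os.path.join(checkpoint_dir, "config.json"))
ft_model = Xtts.init_from_config(ft_config)
ft_model.load_checkpoint(ft_config, checkpoint_dir=checkpoint_dir)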
@@ -14,8 +14,8 @@ from TTS.utils.manage import ModelManager
 MODELS_WITH_SEP_TESTS = [
     "tts_models/multilingual/multi-dataset/bark",
     "tts_models/en/multi-dataset/tortoise-v2",
-    "tts_models/multilingual/multi-dataset/xtts_v1",
     "tts_models/multilingual/multi-dataset/xtts_v1.1",
+    "tts_models/multilingual/multi-dataset/xtts_v2",
 ]
@@ -82,14 +82,14 @@ def test_xtts():
     if use_gpu:
         run_cli(
             "yes | "
-            f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
+            f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 "
             f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
             f'--speaker_wav "{speaker_wav}" --language_idx "en"'
         )
     else:
         run_cli(
             "yes | "
-            f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
+            f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 "
             f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
             f'--speaker_wav "{speaker_wav}" --language_idx "en"'
         )
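For reference, the CLI invocation exercised above maps onto the high-level Python API roughly as follows. This is a minimal sketch, assuming the `TTS.api.TTS` wrapper and its `tts_to_file` arguments behave as described in the project README; the file paths are placeholders.

from TTS.api import TTS

# Hypothetical paths; substitute a real reference clip and output location.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1", gpu=True)
tts.tts_to_file(
    text="This is an example.",
    speaker_wav="/path/to/reference.wav",
    language="en",
    file_path="output.wav",
)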
@@ -100,8 +100,10 @@ def test_xtts_streaming():
     from TTS.tts.configs.xtts_config import XttsConfig
     from TTS.tts.models.xtts import Xtts

-    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
-    model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
+    speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
+    speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
+    speaker_wav.append(speaker_wav_2)
+    model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1.1")
     config = XttsConfig()
     config.load_json(os.path.join(model_path, "config.json"))
     model = Xtts.init_from_config(config)
@@ -109,7 +111,7 @@ def test_xtts_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)

     print("Inference...")
     chunks = model.inference_stream(
@@ -126,6 +128,87 @@ def test_xtts_streaming():
     assert len(wav_chuncks) > 1


+def test_xtts_v2():
+    """XTTS is too big to run on github actions. We need to test it locally"""
+    output_path = os.path.join(get_tests_output_path(), "output.wav")
+    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+    speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
+    use_gpu = torch.cuda.is_available()
+    if use_gpu:
+        run_cli(
+            "yes | "
+            f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
+            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
+            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
+        )
+    else:
+        run_cli(
+            "yes | "
+            f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
+            f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
+            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
+        )
+
+
+def test_xtts_v2_streaming():
+    """Testing the new inference_stream method"""
+    from TTS.tts.configs.xtts_config import XttsConfig
+    from TTS.tts.models.xtts import Xtts
+
+    speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
+    model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2")
+    config = XttsConfig()
+    config.load_json(os.path.join(model_path, "config.json"))
+    model = Xtts.init_from_config(config)
+    model.load_checkpoint(config, checkpoint_dir=model_path)
+    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+
+    print("Computing speaker latents...")
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+
+    print("Inference...")
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        if i == 0:
+            assert chunk.shape[-1] > 5000
+        wav_chuncks.append(chunk)
+    assert len(wav_chuncks) > 1
+    normal_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=1.5,
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    fast_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=0.66,
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    slow_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    assert slow_len > normal_len
+    assert normal_len > fast_len
+
+
 def test_tortoise():
     output_path = os.path.join(get_tests_output_path(), "output.wav")
     use_gpu = torch.cuda.is_available()
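The streaming tests above only count chunk lengths; to actually listen to the result, the streamed chunks can be concatenated into one waveform. A minimal sketch, assuming the chunks are 1-D torch tensors and the output sample rate is 24 kHz (XttsAudioConfig.output_sample_rate in the recipe above); torchaudio is used here purely for illustration.

import torch
import torchaudio

# `wav_chuncks` as collected in the streaming tests above.
wav = torch.cat(wav_chuncks, dim=-1).cpu()                       # join streamed chunks into one waveform
torchaudio.save("xtts_streaming.wav", wav.unsqueeze(0), 24000)   # mono, assumed 24 kHz output rate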