Merge pull request #717 from coqui-ai/dev

v0.2.0
Eren Gölge 2021-08-11 10:08:04 +02:00 committed by GitHub
commit 01a2b0b5c0
94 changed files with 3610 additions and 657 deletions

View File

@ -1,10 +1,15 @@
---
name: 'Contribution Guideline '
about: Refer to Contribution Guideline
title: ''
labels: ''
assignees: ''
# Pull request guidelines
---
Welcome to the 🐸TTS project! We are excited to see your interest, and appreciate your support!
👐 Please check our [CONTRIBUTION GUIDELINE](https://github.com/coqui-ai/TTS#contribution-guidelines).
This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.
In order to make a good pull request, please see our [CONTRIBUTING.md](CONTRIBUTING.md) file.
Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS).
This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS):
- Protects you, Coqui, and the users of the code.
- Does not change your rights to use your contributions for any purpose.
- Does not change the license of the 🐸TTS project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute.

.gitignore
View File

@ -136,6 +136,7 @@ TTS/tts/layers/glow_tts/monotonic_align/core.c
temp_build/*
recipes/WIP/*
recipes/ljspeech/LJSpeech-1.1/*
recipes/ljspeech/tacotron2-DDC/LJSpeech-1.1/*
events.out*
old_configs/*
model_importers/*
@ -152,4 +153,6 @@ output.wav
tts_output.wav
deps.json
speakers.json
internal/*
internal/*
*_pitch.npy
*_phoneme.npy

View File

@ -4,7 +4,7 @@
help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
target_dirs := tests TTS notebooks
target_dirs := tests TTS notebooks recipes
test_all: ## run tests and don't stop on an error.
nosetests --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id

View File

@ -73,10 +73,13 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950)
### End-to-End Models
- VITS: [paper](https://arxiv.org/pdf/2106.06103)
### Attention Methods
- Guided Attention: [paper](https://arxiv.org/abs/1710.08969)
- Forward Backward Decoding: [paper](https://arxiv.org/abs/1907.09006)
- Graves Attention: [paper](https://arxiv.org/abs/1907.09006)
- Graves Attention: [paper](https://arxiv.org/abs/1910.10288)
- Double Decoder Consistency: [blog](https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency/)
- Dynamic Convolutional Attention: [paper](https://arxiv.org/pdf/1910.10288.pdf)

View File

@ -1,7 +1,7 @@
{
"tts_models":{
"en":{
"ek1":{
"tts_models": {
"en": {
"ek1": {
"tacotron2": {
"description": "EK1 en-rp tacotron2 by NMStoker",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ek1--tacotron2.zip",
@ -9,7 +9,7 @@
"commit": "c802255"
}
},
"ljspeech":{
"ljspeech": {
"tacotron2-DDC": {
"description": "Tacotron2 with Double Decoder Consistency.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/tts_models--en--ljspeech--tacotron2-DDC.zip",
@ -17,9 +17,18 @@
"commit": "bae2ad0f",
"author": "Eren Gölge @erogol",
"license": "",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
},
"glow-tts":{
"tacotron2-DDC_ph": {
"description": "Tacotron2 with Double Decoder Consistency with phonemes.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip",
"default_vocoder": "vocoder_models/en/ljspeech/univnet",
"commit": "3900448",
"author": "Eren Gölge @erogol",
"license": "",
"contact": "egolge@coqui.com"
},
"glow-tts": {
"description": "",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--en--ljspeech--glow-tts.zip",
"stats_file": null,
@ -27,7 +36,7 @@
"commit": "",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
},
"tacotron2-DCA": {
"description": "",
@ -36,19 +45,28 @@
"commit": "",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
},
"speedy-speech-wn":{
"speedy-speech-wn": {
"description": "Speedy Speech model with wavenet decoder.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ljspeech--speedy-speech-wn.zip",
"default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan",
"commit": "77b6145",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
},
"vits": {
"description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--vits.zip",
"default_vocoder": null,
"commit": "3900448",
"author": "Eren Gölge @erogol",
"license": "",
"contact": "egolge@coqui.com"
}
},
"vctk":{
"vctk": {
"sc-glow-tts": {
"description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip",
@ -56,12 +74,19 @@
"commit": "b531fa69",
"author": "Edresson Casanova",
"license": "",
"contact":""
"contact": ""
},
"vits": {
"description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--vctk--vits.zip",
"default_vocoder": null,
"commit": "3900448",
"author": "Eren @erogol",
"license": "",
"contact": "egolge@coqui.ai"
}
},
"sam":{
"sam": {
"tacotron-DDC": {
"description": "Tacotron2 with Double Decoder Consistency trained with Aceenture's Sam dataset.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/tts_models--en--sam--tacotron_DDC.zip",
@ -69,46 +94,47 @@
"commit": "bae2ad0f",
"author": "Eren Gölge @erogol",
"license": "",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
}
}
},
"es":{
"mai":{
"tacotron2-DDC":{
"es": {
"mai": {
"tacotron2-DDC": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--es--mai--tacotron2-DDC.zip",
"default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
"commit": "",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
}
}
},
"fr":{
"mai":{
"tacotron2-DDC":{
"fr": {
"mai": {
"tacotron2-DDC": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--fr--mai--tacotron2-DDC.zip",
"default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
"commit": "",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
}
}
},
"zh-CN":{
"baker":{
"tacotron2-DDC-GST":{
"zh-CN": {
"baker": {
"tacotron2-DDC-GST": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
"commit": "unknown",
"author": "@kirianguiller"
"author": "@kirianguiller",
"default_vocoder": null
}
}
},
"nl":{
"mai":{
"tacotron2-DDC":{
"nl": {
"mai": {
"tacotron2-DDC": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--nl--mai--tacotron2-DDC.zip",
"author": "@r-dh",
"default_vocoder": "vocoder_models/nl/mai/parallel-wavegan",
@ -117,20 +143,9 @@
}
}
},
"ru":{
"ruslan":{
"tacotron2-DDC":{
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--ru--ruslan--tacotron2-DDC.zip",
"author": "@erogol",
"default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
"license":"",
"contact": "egolge@coqui.com"
}
}
},
"de":{
"thorsten":{
"tacotron2-DCA":{
"de": {
"thorsten": {
"tacotron2-DCA": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/tts_models--de--thorsten--tacotron2-DCA.zip",
"default_vocoder": "vocoder_models/de/thorsten/fullband-melgan",
"author": "@thorstenMueller",
@ -138,9 +153,9 @@
}
}
},
"ja":{
"kokoro":{
"tacotron2-DDC":{
"ja": {
"kokoro": {
"tacotron2-DDC": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.15/tts_models--jp--kokoro--tacotron2-DDC.zip",
"default_vocoder": "vocoder_models/universal/libri-tts/wavegrad",
"description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.",
@ -150,54 +165,62 @@
}
}
},
"vocoder_models":{
"universal":{
"libri-tts":{
"wavegrad":{
"vocoder_models": {
"universal": {
"libri-tts": {
"wavegrad": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--wavegrad.zip",
"commit": "ea976b0",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
},
"fullband-melgan":{
"fullband-melgan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--fullband-melgan.zip",
"commit": "4132240",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
}
}
},
"en": {
"ek1":{
"ek1": {
"wavegrad": {
"description": "EK1 en-rp wavegrad by NMStoker",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--en--ek1--wavegrad.zip",
"commit": "c802255"
}
},
"ljspeech":{
"multiband-melgan":{
"ljspeech": {
"multiband-melgan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--en--ljspeech--mulitband-melgan.zip",
"commit": "ea976b0",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact":"egolge@coqui.com"
"contact": "egolge@coqui.com"
},
"hifigan_v2":{
"hifigan_v2": {
"description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--ljspeech-hifigan_v2.zip",
"commit": "bae2ad0f",
"author": "@erogol",
"license": "",
"contact": "egolge@coqui.ai"
},
"univnet": {
"description": "UnivNet model trained on LJSpeech to complement the TacotronDDC_ph model.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/vocoder_models--en--ljspeech--univnet.zip",
"commit": "3900448",
"author": "Eren @erogol",
"license": "",
"contact": "egolge@coqui.ai"
}
},
"vctk":{
"hifigan_v2":{
"vctk": {
"hifigan_v2": {
"description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts",
"github_rls_url":"https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--vctk--hifigan_v2.zip",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--vctk--hifigan_v2.zip",
"commit": "2f07160",
"author": "Edresson Casanova",
"license": "",
@ -205,9 +228,9 @@
}
},
"sam": {
"hifigan_v2":{
"hifigan_v2": {
"description": "Finetuned and intended to be used with tts_models/en/sam/tacotron_DDC",
"github_rls_url":"https://github.com/coqui-ai/TTS/releases/download/v0.0.13/vocoder_models--en--sam--hifigan_v2.zip",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/vocoder_models--en--sam--hifigan_v2.zip",
"commit": "2f07160",
"author": "Eren Gölge @erogol",
"license": "",
@ -215,28 +238,38 @@
}
}
},
"nl":{
"mai":{
"parallel-wavegan":{
"nl": {
"mai": {
"parallel-wavegan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--nl--mai--parallel-wavegan.zip",
"author": "@r-dh",
"commit": "unknown"
}
}
},
"de":{
"thorsten":{
"wavegrad":{
"de": {
"thorsten": {
"wavegrad": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/vocoder_models--de--thorsten--wavegrad.zip",
"author": "@thorstenMueller",
"commit": "unknown"
},
"fullband-melgan":{
"fullband-melgan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.3/vocoder_models--de--thorsten--fullband-melgan.zip",
"author": "@thorstenMueller",
"commit": "unknown"
}
}
},
"ja": {
"kokoro": {
"hifigan_v1": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/vocoder_models--ja--kokoro--hifigan_v1.zip",
"description": "HifiGAN model trained for kokoro dataset by @kaiidams",
"author": "@kaiidams",
"commit": "3900448"
}
}
}
}
}
}
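The new v0.2.0 entries above (for example `tts_models/en/ljspeech/vits`) are consumed through the model manager like any other listed model. The following is a minimal usage sketch, not part of this changeset: the `ModelManager`/`Synthesizer` call shapes mirror their usage elsewhere in this diff, while the `.models.json` path and the `tts()`/`save_wav()` calls are assumptions.

# Minimal sketch (not part of this diff): fetch and run one of the new models.
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

manager = ModelManager("TTS/.models.json")  # path to the file shown above (assumed location)
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/vits")

# VITS is end-to-end, so no separate vocoder is needed ("default_vocoder": null above).
synthesizer = Synthesizer(model_path, config_path, None, None, None, use_cuda=False)
wav = synthesizer.tts("Hello from the v0.2.0 VITS model.")
synthesizer.save_wav(wav, "output.wav")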

View File

@ -6,7 +6,7 @@ import numpy as np
import tensorflow as tf
import torch
from TTS.utils.io import load_config
from TTS.utils.io import load_config, load_fsspec
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf,
convert_tf_name,
@ -33,7 +33,7 @@ num_speakers = 0
# init torch model
model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()
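The `load_fsspec` helper swapped in above is not included in this excerpt; the sketch below is an assumption about its behavior, namely that it routes `torch.load` through `fsspec.open` so local and remote checkpoint paths both work.

# Hedged sketch of the load_fsspec helper imported above; the real
# TTS.utils.io implementation is not shown in this diff.
import fsspec
import torch

def load_fsspec(path, map_location=None, **kwargs):
    with fsspec.open(path, "rb") as f:  # fsspec resolves local, gs://, s3:// paths
        return torch.load(f, map_location=map_location, **kwargs)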

View File

@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config
from TTS.utils.io import load_config, load_fsspec
sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
@ -32,7 +32,7 @@ num_speakers = 0
# init torch model
model = setup_model(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)

View File

@ -16,6 +16,7 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
from TTS.utils.io import load_fsspec
use_cuda = torch.cuda.is_available()
@ -239,7 +240,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_model(c)
# restore model
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
if use_cuda:

View File

@ -208,7 +208,7 @@ def main():
if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE3: set custome model paths
# CASE3: set custom model paths
if args.model_path is not None:
model_path = args.model_path
config_path = args.config_path

View File

@ -17,6 +17,7 @@ from TTS.trainer import init_training
from TTS.tts.datasets import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, check_update
@ -115,12 +116,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
"step_time": step_time,
"avg_loader_time": avg_loader_time,
}
tb_logger.tb_train_epoch_stats(global_step, train_stats)
dashboard_logger.train_epoch_stats(global_step, train_stats)
figures = {
# FIXME: not constant
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
}
tb_logger.tb_train_figures(global_step, figures)
dashboard_logger.train_figures(global_step, figures)
if global_step % c.print_step == 0:
print(
@ -169,7 +170,7 @@ def main(args): # pylint: disable=redefined-outer-name
raise Exception("The %s not is a loss supported" % c.loss)
if args.restore_path:
checkpoint = torch.load(args.restore_path)
checkpoint = load_fsspec(args.restore_path)
try:
model.load_state_dict(checkpoint["model"])
@ -207,7 +208,7 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == "__main__":
args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv)
try:
main(args)

View File

@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training
def main():
"""Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False)
args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False)
trainer.fit()

View File

@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder
def main():
try:
args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()
except KeyboardInterrupt:
remove_experiment_folder(output_path)

View File

@ -3,6 +3,7 @@ import os
import re
from typing import Dict
import fsspec
import yaml
from coqpit import Coqpit
@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module
def read_json_with_comments(json_path):
"""for backward compat."""
# fallback to json
with open(json_path, "r", encoding="utf-8") as f:
with fsspec.open(json_path, "r", encoding="utf-8") as f:
input_str = f.read()
# handle comments
input_str = re.sub(r"\\\n", "", input_str)
@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
config_dict = {}
ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"):
with open(config_path, "r", encoding="utf-8") as f:
with fsspec.open(config_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
elif ext == ".json":
try:
with open(config_path, "r", encoding="utf-8") as f:
input_str = f.read()
data = json.loads(input_str)
with fsspec.open(config_path, "r", encoding="utf-8") as f:
data = json.load(f)
except json.decoder.JSONDecodeError:
# backwards compat.
data = read_json_with_comments(config_path)
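With `open` replaced by `fsspec.open`, `load_config` can read configs from any fsspec-supported location. A small usage sketch follows; the paths are illustrative and remote schemes need the matching fsspec backend (e.g. gcsfs) installed.

# Usage sketch for the fsspec-backed config loading above (paths illustrative).
from TTS.utils.io import load_config  # import path as used elsewhere in this diff

local_cfg = load_config("config.json")                            # local file, as before
remote_cfg = load_config("gs://my-bucket/runs/vits/config.json")  # needs gcsfs installed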

View File

@ -36,6 +36,10 @@ class BaseAudioConfig(Coqpit):
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
do_trim_silence (bool):
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
trim_db (int):
Silence threshold used for silence trimming. Defaults to 45.
power (float):
@ -79,7 +83,7 @@ class BaseAudioConfig(Coqpit):
preemphasis: float = 0.0
ref_level_db: int = 20
do_sound_norm: bool = False
log_func = "np.log10"
log_func: str = "np.log10"
# silence trimming
do_trim_silence: bool = True
trim_db: int = 45
@ -91,6 +95,8 @@ class BaseAudioConfig(Coqpit):
mel_fmin: float = 0.0
mel_fmax: float = None
spec_gain: int = 20
do_amp_to_db_linear: bool = True
do_amp_to_db_mel: bool = True
# normalization params
signal_norm: bool = True
min_level_db: int = -100
@ -182,51 +188,87 @@ class BaseTrainingConfig(Coqpit):
Args:
model (str):
Name of the model that is used in the training.
run_name (str):
Name of the experiment. This prefixes the output folder name.
run_description (str):
Short description of the experiment.
epochs (int):
Number training epochs. Defaults to 10000.
batch_size (int):
Training batch size.
eval_batch_size (int):
Validation batch size.
mixed_precision (bool):
Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
it may also cause numerical instability in some cases.
scheduler_after_epoch (bool):
If true, run the scheduler step after each epoch else run it after each model step.
run_eval (bool):
Enable / Disable evaluation (validation) run. Defaults to True.
test_delay_epochs (int):
Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
results, hence waiting for a couple of epochs might save some time.
print_eval (bool):
Enable / Disable console logging for evaluation steps. If disabled then it only shows the final values at
the end of the evaluation. Default to ```False```.
print_step (int):
Number of steps required to print the next training log.
tb_plot_step (int):
dashboard_logger (str): "tensorboard" or "wandb"
Sets the experiment tracking tool. Defaults to "tensorboard".
plot_step (int):
Number of steps required to log training on Tensorboard.
tb_model_param_stats (bool):
model_param_stats (bool):
Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
Defaults to ```False```.
project_name (str):
Name of the project. Defaults to config.model
wandb_entity (str):
Name of W&B entity/team. Enables collaboration across a team or org.
log_model_step (int):
Number of steps required to log a checkpoint as W&B artifact
save_step (int):
Number of steps required to save the next checkpoint.
checkpoint (bool):
Enable / Disable checkpointing.
keep_all_best (bool):
Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
to ```False```.
keep_after (int):
Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
to 10000.
num_loader_workers (int):
Number of workers for training time dataloader.
num_eval_loader_workers (int):
Number of workers for evaluation time dataloader.
output_path (str):
Path for training output folder. The nonexist part of the given path is created automatically.
All training outputs are saved there.
Path for training output folder, either a local file path or other
URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
S3 (s3://) paths. The nonexistent part of the given path is created
automatically. All training artefacts are saved there.
"""
model: str = None
@ -237,14 +279,19 @@ class BaseTrainingConfig(Coqpit):
batch_size: int = None
eval_batch_size: int = None
mixed_precision: bool = False
scheduler_after_epoch: bool = False
# eval params
run_eval: bool = True
test_delay_epochs: int = 0
print_eval: bool = False
# logging
dashboard_logger: str = "tensorboard"
print_step: int = 25
tb_plot_step: int = 100
tb_model_param_stats: bool = False
plot_step: int = 100
model_param_stats: bool = False
project_name: str = None
log_model_step: int = None
wandb_entity: str = None
# checkpointing
save_step: int = 10000
checkpoint: bool = True
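Taken together, the renamed fields map the old Tensorboard-specific options onto the new dashboard abstraction. A hedged configuration sketch using only the fields documented above follows; the import path and concrete values are assumptions.

# Hedged sketch of the new logging-related fields (field names come from the
# docstring and dataclass above; import path and values are assumptions).
from TTS.config.shared_configs import BaseTrainingConfig

config = BaseTrainingConfig(
    model="tacotron2",
    run_name="ljspeech-ddc",
    dashboard_logger="wandb",        # "tensorboard" (default) or "wandb"
    project_name="coqui-tts",        # W&B project; defaults to config.model
    wandb_entity="my-team",          # W&B team/entity, optional
    plot_step=100,                   # was tb_plot_step
    model_param_stats=False,         # was tb_model_param_stats
    log_model_step=10000,            # log a checkpoint artifact every N steps
    output_path="s3://my-bucket/tts-runs/",  # local path or fsspec/tensorboardX URL
)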

View File

@ -103,8 +103,8 @@ synthesizer = Synthesizer(
model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
)
use_multi_speaker = synthesizer.tts_model.speaker_manager is not None and synthesizer.tts_model.num_speakers > 1
speaker_manager = synthesizer.tts_model.speaker_manager if hasattr(synthesizer.tts_model, "speaker_manager") else None
use_multi_speaker = hasattr(synthesizer.tts_model, "speaker_manager") and synthesizer.tts_model.num_speakers > 1
speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
# TODO: set this from SpeakerManager
use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__)

View File

@ -2,6 +2,8 @@ import numpy as np
import torch
from torch import nn
from TTS.utils.io import load_fsspec
class LSTMWithProjection(nn.Module):
def __init__(self, input_size, hidden_size, proj_size):
@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module):
# pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if use_cuda:
self.cuda()

View File

@ -2,6 +2,8 @@ import numpy as np
import torch
import torch.nn as nn
from TTS.utils.io import load_fsspec
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module):
return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if use_cuda:
self.cuda()

View File

@ -6,11 +6,11 @@ import re
from multiprocessing import Manager
import numpy as np
import torch
from scipy import signal
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
class Storage(object):
@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
torch.save(state, checkpoint_path)
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path)
save_fsspec(state, bestmodel_path)
return best_loss

View File

@ -1,7 +1,7 @@
import datetime
import os
import torch
from TTS.utils.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
torch.save(state, checkpoint_path)
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s
bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path)
save_fsspec(state, bestmodel_path)
return best_loss
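As with `load_fsspec`, the `save_fsspec` helper used above is not part of this excerpt; a plausible minimal sketch, assuming it simply writes the checkpoint through `fsspec.open` so `out_path` may be local or remote:

# Companion sketch for save_fsspec (the real TTS.utils.io helper is not shown here).
import fsspec
import torch

def save_fsspec(state, path, **kwargs):
    with fsspec.open(path, "wb") as f:
        torch.save(state, f, **kwargs)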

View File

@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
import glob
import importlib
import logging
import multiprocessing
import os
import platform
import re
@ -12,7 +12,9 @@ import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse
import fsspec
import torch
from coqpit import Coqpit
from torch import nn
@ -29,18 +31,20 @@ from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (
KeepAverage,
count_parameters,
create_experiment_folder,
get_experiment_folder_path,
get_git_branch,
remove_experiment_folder,
set_init_dict,
to_cuda,
)
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model
multiprocessing.set_start_method("fork")
if platform.system() != "Windows":
# https://github.com/pytorch/pytorch/issues/973
import resource
@ -48,6 +52,7 @@ if platform.system() != "Windows":
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
if is_apex_available():
from apex import amp
@ -87,7 +92,7 @@ class Trainer:
config: Coqpit,
output_path: str,
c_logger: ConsoleLogger = None,
tb_logger: TensorboardLogger = None,
dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
model: nn.Module = None,
cudnn_benchmark: bool = False,
) -> None:
@ -112,7 +117,7 @@ class Trainer:
c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default
console logger is used. Defaults to None.
tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used.
Defaults to None.
model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
@ -134,8 +139,8 @@ class Trainer:
Running trainer on a config.
>>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
>>> args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
>>> trainer = Trainer(args, config, output_path, c_logger, tb_logger)
>>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
>>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
>>> trainer.fit()
TODO:
@ -148,22 +153,24 @@ class Trainer:
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
if config is None:
# parse config from console arguments
config, output_path, _, c_logger, tb_logger = process_args(args)
config, output_path, _, c_logger, dashboard_logger = process_args(args)
self.output_path = output_path
self.args = args
self.config = config
self.config.output_log_path = output_path
# init loggers
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
if tb_logger is None:
self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
else:
self.tb_logger = tb_logger
self.dashboard_logger = dashboard_logger
if self.dashboard_logger is None:
self.dashboard_logger = init_logger(config)
if not self.config.log_model_step:
self.config.log_model_step = self.config.save_step
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
@ -173,7 +180,6 @@ class Trainer:
self.best_loss = float("inf")
self.train_loader = None
self.eval_loader = None
self.output_audio_path = os.path.join(output_path, "test_audios")
self.keep_avg_train = None
self.keep_avg_eval = None
@ -184,7 +190,7 @@ class Trainer:
# init audio processor
self.ap = AudioProcessor(**self.config.audio.to_dict())
# load dataset samples
# load data samples
# TODO: refactor this
if "datasets" in self.config:
# load data for `tts` models
@ -205,6 +211,10 @@ class Trainer:
else:
self.model = self.get_model(self.config)
# init multispeaker settings of the model
if hasattr(self.model, "init_multispeaker"):
self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
# setup criterion
self.criterion = self.get_criterion(self.model)
@ -274,9 +284,9 @@ class Trainer:
"""
# TODO: better model setup
try:
model = setup_tts_model(config)
except ModuleNotFoundError:
model = setup_vocoder_model(config)
except ModuleNotFoundError:
model = setup_tts_model(config)
return model
def restore_model(
@ -309,7 +319,7 @@ class Trainer:
return obj
print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = torch.load(restore_path)
checkpoint = load_fsspec(restore_path)
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
@ -417,7 +427,7 @@ class Trainer:
scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access
config: Coqpit,
optimizer_idx: int = None,
) -> Tuple[Dict, Dict, int, torch.Tensor]:
) -> Tuple[Dict, Dict, int]:
"""Perform a forward - backward pass and run the optimizer.
Args:
@ -426,7 +436,7 @@ class Trainer:
optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
scaler (AMPScaler): AMP scaler.
criterion (nn.Module): Model's criterion.
scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer.
scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
config (Coqpit): Model config.
optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
@ -436,6 +446,7 @@ class Trainer:
Returns:
Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
"""
step_start_time = time.time()
# zero-out optimizer
optimizer.zero_grad()
@ -448,11 +459,11 @@ class Trainer:
# skip the rest
if outputs is None:
step_time = time.time() - step_start_time
return None, {}, step_time, 0
return None, {}, step_time
# check nan loss
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
# set gradient clipping threshold
if "grad_clip" in config and config.grad_clip is not None:
@ -463,7 +474,6 @@ class Trainer:
else:
grad_clip = 0.0 # meaning no gradient clipping
# TODO: compute grad norm
if grad_clip <= 0:
grad_norm = 0
@ -474,15 +484,17 @@ class Trainer:
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
scaled_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer),
grad_clip,
amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
)
else:
# model optimizer step in mixed precision mode
scaler.scale(loss_dict["loss"]).backward()
scaler.unscale_(optimizer)
if grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
grad_norm = 0
scale_prev = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
@ -491,13 +503,13 @@ class Trainer:
# main model optimizer step
loss_dict["loss"].backward()
if grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
optimizer.step()
step_time = time.time() - step_start_time
# setup lr
if scheduler is not None and update_lr_scheduler:
if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
scheduler.step()
# detach losses
@ -505,7 +517,9 @@ class Trainer:
if optimizer_idx is not None:
loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
return outputs, loss_dict, step_time, grad_norm
else:
loss_dict["grad_norm"] = grad_norm
return outputs, loss_dict, step_time
@staticmethod
def _detach_loss_dict(loss_dict: Dict) -> Dict:
@ -544,11 +558,10 @@ class Trainer:
# containers to hold model outputs and losses for each optimizer.
outputs_per_optimizer = None
log_dict = {}
loss_dict = {}
if not isinstance(self.optimizer, list):
# training with a single optimizer
outputs, loss_dict_new, step_time, grad_norm = self._optimize(
outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
)
loss_dict.update(loss_dict_new)
@ -560,25 +573,36 @@ class Trainer:
criterion = self.criterion
scaler = self.scaler[idx] if self.use_amp_scaler else None
scheduler = self.scheduler[idx]
outputs, loss_dict_new, step_time, grad_norm = self._optimize(
outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
)
# skip the rest if the model returns None
total_step_time += step_time
outputs_per_optimizer[idx] = outputs
# merge loss_dicts from each optimizer
# rename duplicates with the optimizer idx
# if None, model skipped this optimizer
if loss_dict_new is not None:
loss_dict.update(loss_dict_new)
for k, v in loss_dict_new.items():
if k in loss_dict:
loss_dict[f"{k}-{idx}"] = v
else:
loss_dict[k] = v
step_time = total_step_time
outputs = outputs_per_optimizer
# update avg stats
# update avg runtime stats
keep_avg_update = dict()
for key, value in log_dict.items():
keep_avg_update["avg_" + key] = value
keep_avg_update["avg_loader_time"] = loader_time
keep_avg_update["avg_step_time"] = step_time
self.keep_avg_train.update_values(keep_avg_update)
# update avg loss stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
self.keep_avg_train.update_values(update_eval_values)
# print training progress
if self.total_steps_done % self.config.print_step == 0:
# log learning rates
@ -590,33 +614,27 @@ class Trainer:
else:
current_lr = self.optimizer.param_groups[0]["lr"]
lrs = {"current_lr": current_lr}
log_dict.update(lrs)
if grad_norm > 0:
log_dict.update({"grad_norm": grad_norm})
# log run-time stats
log_dict.update(
loss_dict.update(
{
"step_time": round(step_time, 4),
"loader_time": round(loader_time, 4),
}
)
self.c_logger.print_train_step(
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
)
if self.args.rank == 0:
# Plot Training Iter Stats
# reduce TB load and don't log every step
if self.total_steps_done % self.config.tb_plot_step == 0:
iter_stats = log_dict
iter_stats.update(loss_dict)
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
if self.total_steps_done % self.config.plot_step == 0:
self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict)
if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
if self.config.checkpoint:
# checkpoint the model
model_loss = (
loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
)
target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
save_checkpoint(
self.config,
self.model,
@ -625,8 +643,14 @@ class Trainer:
self.total_steps_done,
self.epochs_done,
self.output_path,
model_loss=model_loss,
model_loss=target_avg_loss,
)
if self.total_steps_done % self.config.log_model_step == 0:
# log checkpoint as artifact
aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
# training visualizations
figures, audios = None, None
if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
@ -634,11 +658,13 @@ class Trainer:
elif hasattr(self.model, "train_log"):
figures, audios = self.model.train_log(self.ap, batch, outputs)
if figures is not None:
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
self.dashboard_logger.train_figures(self.total_steps_done, figures)
if audios is not None:
self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.total_steps_done += 1
self.callbacks.on_train_step_end()
self.dashboard_logger.flush()
return outputs, loss_dict
def train_epoch(self) -> None:
@ -663,9 +689,17 @@ class Trainer:
if self.args.rank == 0:
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
if self.config.tb_model_param_stats:
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats)
if self.config.model_param_stats:
self.logger.model_weights(self.model, self.total_steps_done)
# scheduler step after the epoch
if self.scheduler is not None and self.config.scheduler_after_epoch:
if isinstance(self.scheduler, list):
for scheduler in self.scheduler:
if scheduler is not None:
scheduler.step()
else:
self.scheduler.step()
@staticmethod
def _model_eval_step(
@ -701,19 +735,22 @@ class Trainer:
Tuple[Dict, Dict]: Model outputs and losses.
"""
with torch.no_grad():
outputs_per_optimizer = None
outputs = []
loss_dict = {}
if not isinstance(self.optimizer, list):
outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
else:
outputs_per_optimizer = [None] * len(self.optimizer)
outputs = [None] * len(self.optimizer)
for idx, _ in enumerate(self.optimizer):
criterion = self.criterion
outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
outputs_per_optimizer[idx] = outputs
outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
outputs[idx] = outputs_
if loss_dict_new is not None:
loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
loss_dict.update(loss_dict_new)
outputs = outputs_per_optimizer
loss_dict = self._detach_loss_dict(loss_dict)
# update avg stats
update_eval_values = dict()
@ -755,28 +792,35 @@ class Trainer:
elif hasattr(self.model, "eval_log"):
figures, audios = self.model.eval_log(self.ap, batch, outputs)
if figures is not None:
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
self.dashboard_logger.eval_figures(self.total_steps_done, figures)
if audios is not None:
self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
def test_run(self) -> None:
"""Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"):
if self.eval_loader is None:
self.eval_loader = self.get_eval_dataloader(
self.ap,
self.data_eval,
verbose=True,
)
if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None)
else:
figures, audios = self.model.test_run(self.ap)
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.dashboard_logger.test_figures(self.total_steps_done, figures)
def _fit(self) -> None:
"""🏃 train -> evaluate -> test for the number of epochs."""
if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.")
self.total_steps_done = self.restore_step
@ -802,10 +846,13 @@ class Trainer:
"""Where the ✨magic✨ happens..."""
try:
self._fit()
self.dashboard_logger.finish()
except KeyboardInterrupt:
self.callbacks.on_keyboard_interrupt()
# if the output folder is empty remove the run.
remove_experiment_folder(self.output_path)
# finish the wandb run and sync data
self.dashboard_logger.finish()
# stop without error signal
try:
sys.exit(0)
@ -816,10 +863,33 @@ class Trainer:
traceback.print_exc()
sys.exit(1)
def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
"""Pick the target loss to compare models"""
target_avg_loss = None
# return if target loss defined in the model config
if "target_loss" in self.config and self.config.target_loss:
return keep_avg_target[f"avg_{self.config.target_loss}"]
# take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
if isinstance(self.optimizer, list):
target_avg_loss = 0
for idx in range(len(self.optimizer)):
target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
target_avg_loss /= len(self.optimizer)
else:
target_avg_loss = keep_avg_target["avg_loss"]
return target_avg_loss
def save_best_model(self) -> None:
"""Save the best model. It only saves if the current target loss is smaller then the previous."""
# set the target loss to choose the best model
target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
# save the model and update the best_loss
self.best_loss = save_best_model(
self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
target_loss_dict,
self.best_loss,
self.config,
self.model,
@ -834,9 +904,16 @@ class Trainer:
@staticmethod
def _setup_logger_config(log_file: str) -> None:
logging.basicConfig(
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
handlers = [logging.StreamHandler()]
# Only add a log file if the output location is local due to poor
# support for writing logs to file-like objects.
parsed_url = urlparse(log_file)
if not parsed_url.scheme or parsed_url.scheme == "file":
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
handlers.append(logging.FileHandler(schemeless_path))
logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
@staticmethod
def _is_apex_available() -> bool:
@ -920,28 +997,33 @@ class Trainer:
return criterion
def init_arguments():
def getarguments():
train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
def get_last_checkpoint(path):
def get_last_checkpoint(path: str) -> Tuple[str, str]:
"""Get latest checkpoint or/and best model in path.
It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`.
Args:
path (list): Path to files to be compared.
path: Path to files to be compared.
Raises:
ValueError: If no checkpoint or best_model files are found.
Returns:
last_checkpoint (str): Last checkpoint filename.
Path to the last checkpoint
Path to best checkpoint
"""
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
fs = fsspec.get_mapper(path).fs
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
scheme = urlparse(path).scheme
if scheme: # scheme is not preserved in fs.glob, add it back
file_names = [scheme + "://" + file_name for file_name in file_names]
last_models = {}
last_model_nums = {}
for key in ["checkpoint", "best_model"]:
@ -963,7 +1045,7 @@ def get_last_checkpoint(path):
key_file_names = [fn for fn in file_names if key in fn]
if last_model is None and len(key_file_names) > 0:
last_model = max(key_file_names, key=os.path.getctime)
last_model_num = torch.load(last_model)["step"]
last_model_num = load_fsspec(last_model)["step"]
if last_model is not None:
last_models[key] = last_model
@ -997,8 +1079,8 @@ def process_args(args, config=None):
audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
logging to the console.
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
the TensorBoard logging.
dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
TODO:
- Interactive config definition.
@ -1012,6 +1094,7 @@ def process_args(args, config=None):
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model
# init config if not already defined
if config is None:
if args.config_path:
@ -1030,12 +1113,12 @@ def process_args(args, config=None):
print(" > Mixed precision mode is ON")
experiment_path = args.continue_path
if not experiment_path:
experiment_path = create_experiment_folder(config.output_path, config.run_name)
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
audio_path = os.path.join(experiment_path, "test_audios")
config.output_log_path = experiment_path
# setup rank 0 process in distributed training
tb_logger = None
dashboard_logger = None
if args.rank == 0:
os.makedirs(audio_path, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
@ -1043,17 +1126,20 @@ def process_args(args, config=None):
# if model characters are not set in the config file
# save the default set to the config file for future
# compatibility.
if config.has("characters_config"):
if config.has("characters") and config.characters is None:
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
os.chmod(audio_path, 0o775)
os.chmod(experiment_path, 0o775)
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
dashboard_logger = init_logger(config)
c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, tb_logger
return config, experiment_path, audio_path, c_logger, dashboard_logger
def init_arguments():
train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
@ -1063,5 +1149,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
else:
parser = init_arguments()
args = parser.parse_known_args()
config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config)
return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
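The trainer now defers logger construction to `init_logger(config)`, whose implementation is not part of this excerpt. The sketch below is a presumed shape based on the imports and the `dashboard_logger`/`project_name`/`wandb_entity` config fields introduced here; in particular the `WandbLogger` keyword arguments are assumptions.

# Presumed shape of init_logger; the actual TTS.utils.logging implementation
# is not shown in this diff, and the WandbLogger kwargs are assumptions.
from TTS.utils.logging import TensorboardLogger, WandbLogger

def init_logger(config):
    if config.dashboard_logger == "tensorboard":
        return TensorboardLogger(config.output_log_path, model_name=config.model)
    if config.dashboard_logger == "wandb":
        return WandbLogger(
            project=config.project_name or config.model,
            name=config.run_name,
            config=config,
            entity=config.wandb_entity,
        )
    raise ValueError(f"Unknown dashboard_logger: {config.dashboard_logger}")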

View File

@ -13,12 +13,16 @@ class GSTConfig(Coqpit):
Args:
gst_style_input_wav (str):
Path to the wav file used to define the style of the output speech at inference. Defaults to None.
gst_style_input_weights (dict):
Defines the weights for each style token used at inference. Defaults to None.
gst_embedding_dim (int):
Defines the size of the GST embedding vector dimensions. Defaults to 256.
gst_num_heads (int):
Number of attention heads used by the multi-head attention. Defaults to 4.
gst_num_style_tokens (int):
Number of style token vectors. Defaults to 10.
"""
@ -51,17 +55,23 @@ class CharactersConfig(Coqpit):
Args:
pad (str):
characters in place of empty padding. Defaults to None.
eos (str):
characters showing the end of a sentence. Defaults to None.
bos (str):
characters showing the beginning of a sentence. Defaults to None.
characters (str):
character set used by the model. Characters not in this list are ignored when converting input text to
a list of sequence IDs. Defaults to None.
punctuations (str):
characters considered as punctuation when parsing the input sentence. Defaults to None.
phonemes (str):
characters treated as phonemes when parsing the input text. Defaults to None.
unique (bool):
remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old
models trained with character lists with duplicates.
@ -95,54 +105,78 @@ class BaseTTSConfig(BaseTrainingConfig):
Args:
audio (BaseAudioConfig):
Audio processor config object instance.
use_phonemes (bool):
enable / disable phoneme use.
use_espeak_phonemes (bool):
enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`).
compute_input_seq_cache (bool):
enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of
the training, it allows faster data loader time and precise limitation with `max_seq_len` and
`min_seq_len`.
text_cleaner (str):
Name of the text cleaner used for cleaning and formatting transcripts.
enable_eos_bos_chars (bool):
enable / disable the use of eos and bos characters.
test_senteces_file (str):
Path to a txt file that has sentences used at test time. The file must have a sentence per line.
phoneme_cache_path (str):
Path to the output folder caching the computed phonemes for each sample.
characters (CharactersConfig):
Instance of a CharactersConfig class.
batch_group_size (int):
Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
prevent using the same batches for each epoch.
loss_masking (bool):
enable / disable masking loss values against padded segments of samples in a batch.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
compute_f0 (int):
(Not in use yet).
compute_linear_spec (bool):
If True data loader computes and returns linear spectrograms alongside the other data.
use_noise_augment (bool):
Augment the input audio with random noise.
add_blank (bool):
Add a blank character between every two characters. It improves performance for some models at the expense
of slower run-time due to the longer input sequence.
datasets (List[BaseDatasetConfig]):
List of datasets used for training. If multiple datasets are provided, they are merged and used together
for training.
optimizer (str):
Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
Defaults to ``.
optimizer_params (dict):
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
lr_scheduler (str):
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
`TTS.utils.training`. Defaults to ``.
lr_scheduler_params (dict):
Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
test_sentences (List[str]):
List of sentences to be used at testing. Defaults to '[]'
"""
@ -166,6 +200,7 @@ class BaseTTSConfig(BaseTrainingConfig):
min_seq_len: int = 1
max_seq_len: int = float("inf")
compute_f0: bool = False
compute_linear_spec: bool = False
use_noise_augment: bool = False
add_blank: bool = False
# dataset
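As a quick illustration of the character-set options documented in `CharactersConfig` above, a hedged instantiation sketch; the character strings are placeholders rather than the project defaults.

# Illustrative CharactersConfig (fields as documented above; values are
# placeholders, not the shipped defaults).
from TTS.tts.configs.shared_configs import CharactersConfig

characters = CharactersConfig(
    pad="_",
    eos="~",
    bos="^",
    characters="abcdefghijklmnopqrstuvwxyz ",
    punctuations="!'(),-.:;? ",
    phonemes="",      # placeholder; a real config lists the phoneme inventory here
    unique=True,      # drop duplicate characters (compat band-aid, see above)
)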

View File

@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs
@dataclass
class VitsConfig(BaseTTSConfig):
"""Defines parameters for VITS End2End TTS model.
Args:
model (str):
Model name. Do not change unless you know what you are doing.
model_args (VitsArgs):
Model architecture arguments. Defaults to `VitsArgs()`.
grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`.
lr_gen (float):
Initial learning rate for the generator. Defaults to 0.0002.
lr_disc (float):
Initial learning rate for the discriminator. Defaults to 0.0002.
lr_scheduler_gen (str):
Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.
lr_scheduler_gen_params (dict):
Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
lr_scheduler_disc (str):
Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.
lr_scheduler_disc_params (dict):
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
scheduler_after_epoch (bool):
If true, step the schedulers after each epoch; otherwise, after each step. Defaults to `True`.
optimizer (str):
Name of the optimizer to use with both the generator and the discriminator networks. One of the
`torch.optim.*`. Defaults to `AdamW`.
kl_loss_alpha (float):
Loss weight for KL loss. Defaults to 1.0.
disc_loss_alpha (float):
Loss weight for the discriminator loss. Defaults to 1.0.
gen_loss_alpha (float):
Loss weight for the generator loss. Defaults to 1.0.
feat_loss_alpha (float):
Loss weight for the feature matching loss. Defaults to 1.0.
mel_loss_alpha (float):
Loss weight for the mel loss. Defaults to 45.0.
return_wav (bool):
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
min_seq_len (int):
Minimum text length to be considered for training. Defaults to `13`.
max_seq_len (int):
Maximum text length to be considered for training. Defaults to `500`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
add_blank (bool):
If true, a blank token is added between every two characters. Defaults to `True`.
test_sentences (List[str]):
List of sentences to be used for testing.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
Example:
>>> from TTS.tts.configs import VitsConfig
>>> config = VitsConfig()
"""
model: str = "vits"
# model specific params
model_args: VitsArgs = field(default_factory=VitsArgs)
# optimizer
grad_clip: List[float] = field(default_factory=lambda: [5, 5])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
lr_scheduler_disc: str = "ExponentialLR"
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
scheduler_after_epoch: bool = True
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})
# loss params
kl_loss_alpha: float = 1.0
disc_loss_alpha: float = 1.0
gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
# data loader params
return_wav: bool = True
compute_linear_spec: bool = True
# overrides
min_seq_len: int = 13
max_seq_len: int = 500
r: int = 1 # DO NOT CHANGE
add_blank: bool = True
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
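# A short usage sketch for the config above; the overridden values are illustrative only.
from TTS.tts.configs import VitsConfig

config = VitsConfig(lr_gen=1e-4, lr_disc=1e-4, add_blank=True)
print(config.optimizer)           # "AdamW"
print(config.lr_scheduler_gen)    # "ExponentialLR"
config.test_sentences = ["A short smoke-test sentence."]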

View File

@ -23,7 +23,9 @@ class TTSDataset(Dataset):
ap: AudioProcessor,
meta_data: List[List],
characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False,
return_wav: bool = False,
batch_group_size: int = 0,
min_seq_len: int = 0,
max_seq_len: int = float("inf"),
@ -54,9 +56,14 @@ class TTSDataset(Dataset):
characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
set of symbols need to pass it here. Defaults to `None`.
add_blank (bool): Add a special `blank` character between every two characters. It helps some
models achieve better results. Defaults to False.
return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting
sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
batch. Set 0 to disable. Defaults to 0.
@ -95,10 +102,12 @@ class TTSDataset(Dataset):
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
self.characters = characters
self.custom_symbols = custom_symbols
self.add_blank = add_blank
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
@ -109,6 +118,7 @@ class TTSDataset(Dataset):
self.use_noise_augment = use_noise_augment
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
if self.verbose:
@ -128,13 +138,21 @@ class TTSDataset(Dataset):
return data
@staticmethod
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
def _generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
):
"""generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the
config option."""
phonemes = phoneme_to_sequence(
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
text,
[cleaners],
language=language,
enable_eos_bos=False,
custom_symbols=custom_symbols,
tp=characters,
add_blank=add_blank,
)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
@ -142,7 +160,7 @@ class TTSDataset(Dataset):
@staticmethod
def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
@ -153,12 +171,12 @@ class TTSDataset(Dataset):
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, characters, add_blank
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, characters, add_blank
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=characters)
@ -173,6 +191,7 @@ class TTSDataset(Dataset):
else:
text, wav_file, speaker_name = item
attn = None
raw_text = text
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
@ -189,13 +208,19 @@ class TTSDataset(Dataset):
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
)
else:
text = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
@ -209,9 +234,10 @@ class TTSDataset(Dataset):
# return a different sample if the phonemized
# text is longer than the threshold
# TODO: find a better fix
return self.load_data(100)
return self.load_data(self.rescue_item_idx)
sample = {
"raw_text": raw_text,
"text": text,
"wav": wav,
"attn": attn,
@ -238,7 +264,13 @@ class TTSDataset(Dataset):
for idx, item in enumerate(tqdm.tqdm(self.items)):
text, *_ = item
sequence = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
self.items[idx][0] = sequence
@ -249,6 +281,7 @@ class TTSDataset(Dataset):
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
]
@ -329,6 +362,7 @@ class TTSDataset(Dataset):
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
# get pre-computed d-vectors
@ -347,6 +381,14 @@ class TTSDataset(Dataset):
mel_lengths = [m.shape[1] for m in mel]
# lengths adjusted by the reduction factor
mel_lengths_adjusted = [
m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
if m.shape[1] % self.outputs_per_step
else m.shape[1]
for m in mel
]
# compute 'stop token' targets
stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
@ -385,6 +427,20 @@ class TTSDataset(Dataset):
else:
linear = None
# format waveforms
wav_padded = None
if self.return_wav:
wav_lengths = [w.shape[0] for w in wav]
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
wav_lengths = torch.LongTensor(wav_lengths)
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
for i, w in enumerate(wav):
mel_length = mel_lengths_adjusted[i]
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.ap.hop_length]
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
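# e.g. with hop_length=256 and outputs_per_step=1, mel_lengths_adjusted=[80, 64] gives
# max_wav_len = 80 * 256 = 20480; each waveform is edge-padded by 256 samples, cut to its
# own mel_length * 256 samples, and the batch tensor is [B, 20480, 1] after the transpose.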
# collate attention alignments
if batch[0]["attn"] is not None:
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
@ -397,6 +453,7 @@ class TTSDataset(Dataset):
attns = torch.FloatTensor(attns).unsqueeze(1)
else:
attns = None
# TODO: return dictionary
return (
text,
text_lenghts,
@ -409,6 +466,8 @@ class TTSDataset(Dataset):
d_vectors,
speaker_ids,
attns,
wav_padded,
raw_text,
)
raise TypeError(

View File

@ -28,6 +28,31 @@ class LayerNorm(nn.Module):
return x
class LayerNorm2(nn.Module):
"""Layer norm for the 2nd dimension of the input using torch primitive.
Args:
channels (int): number of channels (2nd dimension) of the input.
eps (float): small constant added to the variance to prevent division by zero.
Shapes:
- input: (B, C, T)
- output: (B, C, T)
"""
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class TemporalBatchNorm1d(nn.BatchNorm1d):
"""Normalize each channel separately over time and batch."""

View File

@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
dropout_p (float): Dropout rate used after each conv layer.
"""
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
super().__init__()
# class arguments
self.in_channels = in_channels
@ -33,13 +33,18 @@ class DurationPredictor(nn.Module):
self.norm_2 = LayerNorm(hidden_channels)
# output layer
self.proj = nn.Conv1d(hidden_channels, 1, 1)
if cond_channels is not None and cond_channels != 0:
self.cond = nn.Conv1d(cond_channels, in_channels, 1)
def forward(self, x, x_mask):
def forward(self, x, x_mask, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
if g is not None:
x = x + self.cond(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
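# A usage sketch for the conditional duration predictor above. The import path is assumed
# from the repository layout; `cond_channels` enables the new `g` conditioning input.
import torch
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor  # assumed path

dp = DurationPredictor(in_channels=192, hidden_channels=256, kernel_size=3, dropout_p=0.1, cond_channels=256)
x = torch.randn(2, 192, 50)         # [B, C, T] encoder outputs
x_mask = torch.ones(2, 1, 50)       # [B, 1, T]
g = torch.randn(2, 256, 1)          # [B, C, 1] speaker conditioning
log_durations = dp(x, x_mask, g=g)  # [B, 1, T]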

View File

@ -16,7 +16,7 @@ class ResidualConv1dLayerNormBlock(nn.Module):
::
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|---------------> conv1d_1x1 -----------------------|
|---------------> conv1d_1x1 ------------------|
Args:
in_channels (int): number of input tensor channels.

View File

@ -4,7 +4,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.glow_tts.glow import LayerNorm
from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
class RelativePositionMultiHeadAttention(nn.Module):
@ -271,7 +271,7 @@ class FeedForwardNetwork(nn.Module):
dropout_p (float, optional): dropout rate. Defaults to 0.
"""
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0):
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False):
super().__init__()
self.in_channels = in_channels
@ -280,17 +280,46 @@ class FeedForwardNetwork(nn.Module):
self.kernel_size = kernel_size
self.dropout_p = dropout_p
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2)
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size)
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size)
self.dropout = nn.Dropout(dropout_p)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = self.conv_1(self.padding(x * x_mask))
x = torch.relu(x)
x = self.dropout(x)
x = self.conv_2(x * x_mask)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
return x
@staticmethod
def _pad_shape(padding):
l = padding[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
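# A toy comparison of the two padding helpers above: causal padding puts all kernel_size-1
# zeros on the left (no future context), "same" padding splits them left and right, and both
# keep the time dimension unchanged after a convolution with no implicit padding.
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 10)                   # [B, C, T]
k = 5
x_causal = F.pad(x, (k - 1, 0))             # pad_l = k - 1, pad_r = 0
x_same = F.pad(x, ((k - 1) // 2, k // 2))   # pad_l = 2, pad_r = 2 for k = 5
conv = torch.nn.Conv1d(4, 4, k)
assert conv(x_causal).shape[-1] == x.shape[-1]
assert conv(x_same).shape[-1] == x.shape[-1]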
class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Potional Encoding.
@ -310,20 +339,23 @@ class RelativePositionTransformer(nn.Module):
If default, relative encoding is disabled and it is a regular transformer.
Defaults to None.
input_length (int, optional): input length to limit position encoding. Defaults to None.
layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses the torch layer_norm
primitive. Use type "2"; type "1" is kept for backward compatibility. Defaults to "1".
"""
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
hidden_channels_ffn,
num_heads,
num_layers,
in_channels: int,
out_channels: int,
hidden_channels: int,
hidden_channels_ffn: int,
num_heads: int,
num_layers: int,
kernel_size=1,
dropout_p=0.0,
rel_attn_window_size=None,
input_length=None,
rel_attn_window_size: int = None,
input_length: int = None,
layer_norm_type: str = "1",
):
super().__init__()
self.hidden_channels = hidden_channels
@ -351,7 +383,12 @@ class RelativePositionTransformer(nn.Module):
input_length=input_length,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
if layer_norm_type == "1":
self.norm_layers_1.append(LayerNorm(hidden_channels))
elif layer_norm_type == "2":
self.norm_layers_1.append(LayerNorm2(hidden_channels))
else:
raise ValueError(" [!] Unknown layer norm type")
if hidden_channels != out_channels and (idx + 1) == self.num_layers:
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
@ -366,7 +403,12 @@ class RelativePositionTransformer(nn.Module):
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
if layer_norm_type == "1":
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
elif layer_norm_type == "2":
self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
else:
raise ValueError(" [!] Unknown layer norm type")
def forward(self, x, x_mask):
"""

View File

@ -2,11 +2,13 @@ import math
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import functional
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.ssim import ssim
from TTS.utils.audio import TorchSTFT
# pylint: disable=abstract-method
@ -514,3 +516,142 @@ class AlignTTSLoss(nn.Module):
+ self.mdn_alpha * mdn_loss
)
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
class VitsGeneratorLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
self.kl_loss_alpha = c.kl_loss_alpha
self.gen_loss_alpha = c.gen_loss_alpha
self.feat_loss_alpha = c.feat_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
c.audio.hop_length,
c.audio.win_length,
sample_rate=c.audio.sample_rate,
mel_fmin=c.audio.mel_fmin,
mel_fmax=c.audio.mel_fmax,
n_mels=c.audio.num_mels,
use_mel=True,
do_amp_to_db=True,
)
@staticmethod
def feature_loss(feats_real, feats_generated):
loss = 0
for dr, dg in zip(feats_real, feats_generated):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
@staticmethod
def generator_loss(scores_fake):
loss = 0
gen_losses = []
for dg in scores_fake:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
@staticmethod
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
def forward(
self,
waveform,
waveform_hat,
z_p,
logs_q,
m_p,
logs_p,
z_len,
scores_disc_fake,
feats_disc_fake,
feats_disc_real,
):
"""
Shapes:
- waveform: :math:`[B, 1, T]`
- waveform_hat: :math:`[B, 1, T]`
- z_p: :math:`[B, C, T]`
- logs_q: :math:`[B, C, T]`
- m_p: :math:`[B, C, T]`
- logs_p: :math:`[B, C, T]`
- z_len: :math:`[B]`
- scores_disc_fake[i]: :math:`[B, C]`
- feats_disc_fake[i][j]: :math:`[B, C, T', P]`
- feats_disc_real[i][j]: :math:`[B, C, T', P]`
"""
loss = 0.0
return_dict = {}
z_mask = sequence_mask(z_len).float()
# compute mel spectrograms from the waveforms
mel = self.stft(waveform)
mel_hat = self.stft(waveform_hat)
# compute losses
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl
return_dict["loss_feat"] = loss_feat
return_dict["loss_mel"] = loss_mel
return_dict["loss"] = loss
return return_dict
class VitsDiscriminatorLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
self.disc_loss_alpha = c.disc_loss_alpha
@staticmethod
def discriminator_loss(scores_real, scores_fake):
loss = 0
real_losses = []
fake_losses = []
for dr, dg in zip(scores_real, scores_fake):
dr = dr.float()
dg = dg.float()
real_loss = torch.mean((1 - dr) ** 2)
fake_loss = torch.mean(dg ** 2)
loss += real_loss + fake_loss
real_losses.append(real_loss.item())
fake_losses.append(fake_loss.item())
return loss, real_losses, fake_losses
def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + loss_disc
return_dict["loss_disc"] = loss_disc
return_dict["loss"] = loss
return return_dict
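# A small sketch exercising the KL term defined above with random tensors. The import path
# is assumed; the mask restricts the average to valid (non-padded) frames.
import torch
from TTS.tts.layers.losses import VitsGeneratorLoss  # assumed module path

B, C, T = 2, 192, 40
z_p = torch.randn(B, C, T)
m_p = torch.randn(B, C, T)
logs_p = torch.zeros(B, C, T)
logs_q = torch.zeros(B, C, T)
z_mask = torch.ones(B, 1, T)
kl = VitsGeneratorLoss.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
print(kl)  # scalar tensor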

View File

@ -8,10 +8,10 @@ class GST(nn.Module):
See https://arxiv.org/pdf/1803.09017"""
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None):
super().__init__()
self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim)
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim)
def forward(self, inputs, speaker_embedding=None):
enc_out = self.encoder(inputs)
@ -83,19 +83,19 @@ class ReferenceEncoder(nn.Module):
class StyleTokenLayer(nn.Module):
"""NN Module attending to style tokens based on prosody encodings."""
def __init__(self, num_heads, num_style_tokens, embedding_dim, d_vector_dim=None):
def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
super().__init__()
self.query_dim = embedding_dim // 2
self.query_dim = gst_embedding_dim // 2
if d_vector_dim:
self.query_dim += d_vector_dim
self.key_dim = embedding_dim // num_heads
self.key_dim = gst_embedding_dim // num_heads
self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
nn.init.normal_(self.style_tokens, mean=0, std=0.5)
self.attention = MultiHeadAttention(
query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads
query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads
)
def forward(self, inputs):

View File

@ -0,0 +1,77 @@
import torch
from torch import nn
from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
class DiscriminatorS(torch.nn.Module):
"""HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.
Args:
use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
Tensor: discriminator scores.
List[Tensor]: list of features from the convolutional layers.
"""
feat = []
for l in self.convs:
x = l(x)
x = torch.nn.functional.leaky_relu(x, 0.1)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat
class VitsDiscriminator(nn.Module):
"""VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator.
::
waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
|--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^
Args:
use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of lists of features from each layer of each discriminator.
"""
scores, feats = self.mpd(x)
score_sd, feats_sd = self.sd(x)
return scores + [score_sd], feats + [feats_sd]
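# Usage sketch for the combined discriminator above (import path assumed). Each call returns
# one score tensor and one feature list per sub-discriminator: the period discriminators plus
# the single scale discriminator appended at the end.
import torch
from TTS.tts.layers.vits.discriminator import VitsDiscriminator  # assumed path

disc = VitsDiscriminator()
wav = torch.randn(1, 1, 8192)    # [B, 1, T]
scores, feats = disc(wav)
print(len(scores), len(feats))   # equal lengths, one entry per sub-discriminator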

View File

@ -0,0 +1,271 @@
import math
import torch
from torch import nn
from TTS.tts.layers.glow_tts.glow import WN
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.data import sequence_mask
LRELU_SLOPE = 0.1
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class TextEncoder(nn.Module):
def __init__(
self,
n_vocab: int,
out_channels: int,
hidden_channels: int,
hidden_channels_ffn: int,
num_heads: int,
num_layers: int,
kernel_size: int,
dropout_p: float,
):
"""Text Encoder for VITS model.
Args:
n_vocab (int): Number of characters for the embedding layer.
out_channels (int): Number of channels for the output.
hidden_channels (int): Number of channels for the hidden layers.
hidden_channels_ffn (int): Number of channels for the convolutional layers.
num_heads (int): Number of attention heads for the Transformer layers.
num_layers (int): Number of Transformer layers.
kernel_size (int): Kernel size for the FFN layers in Transformer network.
dropout_p (float): Dropout rate for the Transformer layers.
"""
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
hidden_channels_ffn=hidden_channels_ffn,
num_heads=num_heads,
num_layers=num_layers,
kernel_size=kernel_size,
dropout_p=dropout_p,
layer_norm_type="2",
rel_attn_window_size=4,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
"""
Shapes:
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
dropout_p=0,
cond_channels=0,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.half_channels = channels // 2
self.mean_only = mean_only
# input layer
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
# coupling layers
self.enc = WN(
hidden_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
dropout_p=dropout_p,
c_in_channels=cond_channels,
)
# output layer
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
log_scale = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(log_scale) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(log_scale, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-log_scale) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ResidualCouplingBlocks(nn.Module):
def __init__(
self,
channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
num_layers: int,
num_flows=4,
cond_channels=0,
):
"""Redisual Coupling blocks for VITS flow layers.
Args:
channels (int): Number of input and output tensor channels.
hidden_channels (int): Number of hidden network channels.
kernel_size (int): Kernel size of the WaveNet layers.
dilation_rate (int): Dilation rate of the WaveNet layers.
num_layers (int): Number of the WaveNet layers.
num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
"""
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.num_flows = num_flows
self.cond_channels = cond_channels
self.flows = nn.ModuleList()
for _ in range(num_flows):
self.flows.append(
ResidualCouplingBlock(
channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
cond_channels=cond_channels,
mean_only=True,
)
)
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
x = torch.flip(x, [1])
else:
for flow in reversed(self.flows):
x = torch.flip(x, [1])
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class PosteriorEncoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
num_layers: int,
cond_channels=0,
):
"""Posterior Encoder of VITS model.
::
x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of output tensor channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size of the WaveNet convolution layers.
dilation_rate (int): Dilation rate of the WaveNet layers.
num_layers (int): Number of the WaveNet layers.
cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.cond_channels = cond_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_lengths: :math:`[B, 1]`
- g: :math:`[B, C, 1]`
"""
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
return z, mean, log_scale, x_mask
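# Sketch instantiating the posterior encoder above with typical VITS-style sizes; the import
# path and the concrete sizes are assumptions for illustration only.
import torch
from TTS.tts.layers.vits.networks import PosteriorEncoder  # assumed path

enc = PosteriorEncoder(in_channels=513, out_channels=192, hidden_channels=192,
                       kernel_size=5, dilation_rate=1, num_layers=16)
spec = torch.randn(2, 513, 120)      # [B, C_spec, T] linear spectrogram frames
spec_lens = torch.tensor([120, 97])
z, mean, log_scale, x_mask = enc(spec, spec_lens)
print(z.shape)                        # [2, 192, 120], zeroed beyond each length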

View File

@ -0,0 +1,276 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform
class DilatedDepthSeparableConv(nn.Module):
def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0):
"""Dilated Depth-wise Separable Convolution module.
::
x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o
|-------------------------------------------------------------------------------------^
Args:
channels (int): Number of input and output tensor channels.
kernel_size (int): Kernel size of the depth-wise convolution layers.
num_layers (int): Number of convolution layers.
dropout_p (float, optional): Dropout rate. Defaults to 0.0.
The forward pass returns the network output masked by the input sequence mask.
"""
super().__init__()
self.num_layers = num_layers
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(num_layers):
dilation = kernel_size ** i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm2(channels))
self.norms_2.append(LayerNorm2(channels))
self.dropout = nn.Dropout(dropout_p)
def forward(self, x, x_mask, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
"""
if g is not None:
x = x + g
for i in range(self.num_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.dropout(y)
x = x + y
return x * x_mask
class ElementwiseAffine(nn.Module):
"""Element-wise affine transform like no-population stats BatchNorm alternative.
Args:
channels (int): Number of input tensor channels.
"""
def __init__(self, channels):
super().__init__()
self.translation = nn.Parameter(torch.zeros(channels, 1))
self.log_scale = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument
if not reverse:
y = (x * torch.exp(self.log_scale) + self.translation) * x_mask
logdet = torch.sum(self.log_scale * x_mask, [1, 2])
return y, logdet
x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask
return x
class ConvFlow(nn.Module):
"""Dilated depth separable convolutional based spline flow.
Args:
in_channels (int): Number of input tensor channels.
hidden_channels (int): Number of hidden network channels.
kernel_size (int): Convolutional kernel size.
num_layers (int): Number of convolutional layers.
num_bins (int, optional): Number of spline bins. Defaults to 10.
tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0.
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
num_layers: int,
num_bins=10,
tail_bound=5.0,
):
super().__init__()
self.num_bins = num_bins
self.tail_bound = tail_bound
self.hidden_channels = hidden_channels
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0)
self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
return x
class StochasticDurationPredictor(nn.Module):
"""Stochastic duration predictor with Spline Flows.
It applies Variational Dequantization and Variational Data Augmentation.
Paper:
SDP: https://arxiv.org/pdf/2106.06103.pdf
Spline Flow: https://arxiv.org/abs/1906.04032
::
## Inference
x -> TextCondEncoder() -> Flow() -> dr_hat
noise ----------------------^
## Training
|---------------------|
x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise
d -> DurCondEncoder() -> ^ |
|------------------------------------------------------------------------------|
Args:
in_channels (int): Number of input tensor channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size of convolutional layers.
dropout_p (float): Dropout rate.
num_flows (int, optional): Number of flow blocks. Defaults to 4.
cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0.
"""
def __init__(
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
):
super().__init__()
# condition encoder text
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
# posterior encoder
self.flows = nn.ModuleList()
self.flows.append(ElementwiseAffine(2))
self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
# condition encoder duration
self.post_pre = nn.Conv1d(1, hidden_channels, 1)
self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
# flow layers
self.post_flows = nn.ModuleList()
self.post_flows.append(ElementwiseAffine(2))
self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
if cond_channels != 0 and cond_channels is not None:
self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- dr: :math:`[B, 1, T]`
- g: :math:`[B, C]`
"""
# condition encoder text
x = self.pre(x)
if g is not None:
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert dr is not None
# condition encoder duration
h = self.post_pre(dr)
h = self.post_convs(h, x_mask)
h = self.post_proj(h) * x_mask
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = noise
# posterior encoder
logdet_tot_q = 0.0
for idx, flow in enumerate(self.post_flows):
z_q, logdet_q = flow(z_q, x_mask, g=(x + h))
logdet_tot_q = logdet_tot_q + logdet_q
if idx > 0:
z_q = torch.flip(z_q, [1])
z_u, z_v = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (dr - u) * x_mask
# posterior encoder - neg log likelihood
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
nll_posterior_encoder = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
)
z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
logdet_tot = torch.sum(-z0, [1, 2])
z = torch.cat([z0, z_v], 1)
# flow layers
for idx, flow in enumerate(flows):
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
if idx > 0:
z = torch.flip(z, [1])
# flow layers - neg log likelihood
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
return nll_flow_layers + nll_posterior_encoder
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = torch.flip(z, [1])
z = flow(z, x_mask, g=x, reverse=reverse)
z0, _ = torch.split(z, [1, 1], 1)
logw = z0
return logw
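# Sketch of the two modes of the stochastic duration predictor above (import path assumed):
# the training call returns a per-utterance negative log-likelihood, the reverse call samples
# log-durations from noise for inference.
import torch
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor  # assumed path

sdp = StochasticDurationPredictor(in_channels=192, hidden_channels=192, kernel_size=3, dropout_p=0.5)
x = torch.randn(2, 192, 30)        # [B, C, T] text encoder outputs
x_mask = torch.ones(2, 1, 30)
dr = torch.randint(1, 5, (2, 1, 30)).float()    # [B, 1, T] target durations

nll = sdp(x, x_mask, dr=dr)                             # training: [B] negative log-likelihood
logw = sdp(x, x_mask, reverse=True, noise_scale=0.8)    # inference: [B, 1, T] log-durations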

View File

@ -0,0 +1,203 @@
# adapted from https://github.com/bayesiains/nflows
import numpy as np
import torch
from torch.nn import functional as F
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs,
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet
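# A toy round trip through the spline transform above, using random spline parameters; with
# `tails="linear"` any real-valued input is handled (values outside the tail bound pass through).
import torch
from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform

num_bins = 10
x = torch.randn(2, 1, 25)
w = torch.randn(2, 1, 25, num_bins)        # unnormalized widths
h = torch.randn(2, 1, 25, num_bins)        # unnormalized heights
d = torch.randn(2, 1, 25, num_bins - 1)    # unnormalized derivatives (inner knots only)

y, logabsdet = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear", tail_bound=5.0)
x_rec, _ = piecewise_rational_quadratic_transform(y, w, h, d, inverse=True, tails="linear", tail_bound=5.0)
print(torch.max(torch.abs(x_rec - x)))     # ~0 up to floating point error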

View File

@ -4,20 +4,23 @@ from TTS.utils.generic_utils import find_module
def setup_model(config):
print(" > Using model: {}".format(config.model))
MyModel = find_module("TTS.tts.models", config.model.lower())
# define set of characters used by the model
if config.characters is not None:
# set characters from config
symbols, phonemes = make_symbols(**config.characters.to_dict()) # pylint: disable=redefined-outer-name
if hasattr(MyModel, "make_symbols"):
symbols = MyModel.make_symbols(config)
else:
symbols, phonemes = make_symbols(**config.characters)
else:
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
if config.use_phonemes:
symbols = phonemes
# use default characters and assign them to config
config.characters = parse_symbols()
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
# consider special `blank` character if `add_blank` is set True
num_chars = num_chars + getattr(config, "add_blank", False)
num_chars = len(symbols) + getattr(config, "add_blank", False)
config.num_chars = num_chars
# compatibility fix
if "model_params" in config:

View File

@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@dataclass
@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.text import make_symbols
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
@ -75,9 +76,6 @@ class BaseTacotron(BaseTTS):
self.decoder_backward = None
self.coarse_decoder = None
# init multi-speaker layers
self.init_multispeaker(config)
@staticmethod
def _format_aux_input(aux_input: Dict) -> Dict:
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
@ -113,7 +111,7 @@ class BaseTacotron(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if "r" in state:
self.decoder.set_r(state["r"])
@ -236,6 +234,7 @@ class BaseTacotron(BaseTTS):
def compute_gst(self, inputs, style_input, speaker_embedding=None):
"""Compute global style token"""
if isinstance(style_input, dict):
# multiply each style token with a weight
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
if speaker_embedding is not None:
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
@ -247,8 +246,10 @@ class BaseTacotron(BaseTTS):
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
elif style_input is None:
# ignore style token and return zero tensor
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
else:
# compute style tokens
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
return inputs

View File

@ -1,6 +1,6 @@
import os
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
@ -48,10 +48,17 @@ class BaseTTS(BaseModel):
return get_speaker_manager(config, restore_path, data, out_path)
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
If you need a different behaviour, override this function for your model.
This implementation yields 3 possible outcomes:
1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing.
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
`config.d_vector_dim` or 512.
You can override this function for new models.0
Args:
config (Coqpit): Model configuration.
@ -59,12 +66,24 @@ class BaseTTS(BaseModel):
"""
# init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data)
self.num_speakers = self.speaker_manager.num_speakers
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
if data is not None or self.speaker_manager.speaker_ids:
self.num_speakers = self.speaker_manager.num_speakers
else:
self.num_speakers = (
config.num_speakers
if "num_speakers" in config and config.num_speakers != 0
else self.speaker_manager.num_speakers
)
# set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file:
self.embedded_speaker_dim = (
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
@ -87,7 +106,7 @@ class BaseTTS(BaseModel):
text_input = batch[0]
text_lengths = batch[1]
speaker_names = batch[2]
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
linear_input = batch[3]
mel_input = batch[4]
mel_lengths = batch[5]
stop_targets = batch[6]
@ -95,6 +114,7 @@ class BaseTTS(BaseModel):
d_vectors = batch[8]
speaker_ids = batch[9]
attn_mask = batch[10]
waveform = batch[11]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
@ -140,6 +160,7 @@ class BaseTTS(BaseModel):
"max_text_length": float(max_text_length),
"max_spec_length": float(max_spec_length),
"item_idx": item_idx,
"waveform": waveform,
}
def get_data_loader(
@ -160,15 +181,22 @@ class BaseTTS(BaseModel):
speaker_id_mapping = None
d_vector_mapping = None
# setup custom symbols if needed
custom_symbols = None
if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config)
# init dataloader
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron",
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
meta_data=data_items,
ap=ap,
characters=config.characters,
custom_symbols=custom_symbols,
add_blank=config["add_blank"],
return_wav=config.return_wav if "return_wav" in config else False,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
min_seq_len=config.min_seq_len,
max_seq_len=config.max_seq_len,
@ -185,8 +213,21 @@ class BaseTTS(BaseModel):
)
if config.use_phonemes and config.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(config.num_loader_workers)
if hasattr(self, "eval_data_items") and is_eval:
dataset.items = self.eval_data_items
elif hasattr(self, "train_data_items") and not is_eval:
dataset.items = self.train_data_items
else:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(config.num_loader_workers)
# TODO: find a more efficient solution
# cheap hack - store items in the model state to avoid recomputing when reinitializing the dataset
if is_eval:
self.eval_data_items = dataset.items
else:
self.train_data_items = dataset.items
dataset.sort_items()
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
@ -216,7 +257,7 @@ class BaseTTS(BaseModel):
test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
wav, alignment, model_outputs, _ = synthesis(
outputs_dict = synthesis(
self,
sen,
self.config,
@ -228,9 +269,12 @@ class BaseTTS(BaseModel):
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False
)
return test_figures, test_audios

49
TTS/tts/models/glow_tts.py Executable file → Normal file
View File

@ -12,14 +12,19 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
class GlowTTS(BaseTTS):
"""Glow TTS models from https://arxiv.org/abs/2005.11129
"""GlowTTS model.
Paper abstract:
Paper::
https://arxiv.org/abs/2005.11129
Paper abstract::
Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate
mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained
without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS,
@ -144,7 +149,6 @@ class GlowTTS(BaseTTS):
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop redisual frames wrt num_squeeze and set y_lengths.
@ -361,12 +365,49 @@ class GlowTTS(BaseTTS):
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": train_audio}
@torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs)
@torch.no_grad()
def test_run(self, ap):
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
outputs = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs["outputs"]["model_outputs"], ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
return test_figures, test_audios
def preprocess(self, y, y_lengths, y_max_length, attn=None):
if y_max_length is not None:
y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
@ -382,7 +423,7 @@ class GlowTTS(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@dataclass
@ -105,7 +106,7 @@ class SpeedySpeech(BaseTTS):
if isinstance(config.model_args.length_scale, int)
else config.model_args.length_scale
)
self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
self.emb = nn.Embedding(self.num_chars, config.model_args.hidden_channels)
self.encoder = Encoder(
config.model_args.hidden_channels,
config.model_args.hidden_channels,
@ -227,6 +228,7 @@ class SpeedySpeech(BaseTTS):
outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
return outputs
@torch.no_grad()
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
"""
Shapes:
@ -306,7 +308,7 @@ class SpeedySpeech(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@ -23,19 +23,19 @@ class Tacotron(BaseTacotron):
def __init__(self, config: Coqpit):
super().__init__(config)
self.num_chars, self.config = self.get_characters(config)
chars, self.config = self.get_characters(config)
config.num_chars = self.num_chars = len(chars)
# pass all config fields to `self`
# for fewer code change
for key in config:
setattr(self, key, config[key])
# speaker embedding layer
if self.num_speakers > 1:
# set speaker embedding channel size for determining `in_channels` for the connected layers.
# `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
# on the number of speakers inferred from the dataset.
if self.use_speaker_embedding or self.use_d_vector_file:
self.init_multispeaker(config)
# speaker and gst embeddings are concatenated in the decoder input
if self.num_speakers > 1:
self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim
if self.use_gst:
@ -75,13 +75,11 @@ class Tacotron(BaseTacotron):
if self.gst and self.use_gst:
self.gst_layer = GST(
num_mel=self.decoder_output_dim,
d_vector_dim=self.d_vector_dim
if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
else None,
num_heads=self.gst.gst_num_heads,
num_style_tokens=self.gst.gst_num_style_tokens,
gst_embedding_dim=self.gst.gst_embedding_dim,
)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
@ -106,7 +104,9 @@ class Tacotron(BaseTacotron):
self.max_decoder_steps,
)
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
def forward( # pylint: disable=dangerous-default-value
self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
):
"""
Shapes:
text: [B, T_in]
@ -115,6 +115,7 @@ class Tacotron(BaseTacotron):
mel_lengths: [B]
aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
"""
aux_input = self._format_aux_input(aux_input)
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
inputs = self.embedding(text)
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@ -125,12 +126,10 @@ class Tacotron(BaseTacotron):
# global style token
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
)
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
# speaker embedding
if self.num_speakers > 1:
if not self.use_d_vectors:
if self.use_speaker_embedding or self.use_d_vector_file:
if not self.use_d_vector_file:
# B x 1 x speaker_embed_dim
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
else:
@ -182,7 +181,7 @@ class Tacotron(BaseTacotron):
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
if self.num_speakers > 1:
if not self.use_d_vectors:
if not self.use_d_vector_file:
# B x 1 x speaker_embed_dim
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])
# reshape embedded_speakers

View File

@ -23,7 +23,7 @@ class Tacotron2(BaseTacotron):
super().__init__(config)
chars, self.config = self.get_characters(config)
self.num_chars = len(chars)
config.num_chars = len(chars)
self.decoder_output_dim = config.out_channels
# pass all config fields to `self`
@ -31,12 +31,11 @@ class Tacotron2(BaseTacotron):
for key in config:
setattr(self, key, config[key])
# speaker embedding layer
if self.num_speakers > 1:
# set speaker embedding channel size for determining `in_channels` for the connected layers.
# `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
# on the number of speakers inferred from the dataset.
if self.use_speaker_embedding or self.use_d_vector_file:
self.init_multispeaker(config)
# speaker and gst embeddings are concatenated in the decoder input
if self.num_speakers > 1:
self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim
if self.use_gst:
@ -47,6 +46,7 @@ class Tacotron2(BaseTacotron):
# base model layers
self.encoder = Encoder(self.encoder_in_features)
self.decoder = Decoder(
self.decoder_in_features,
self.decoder_output_dim,
@ -73,9 +73,6 @@ class Tacotron2(BaseTacotron):
if self.gst and self.use_gst:
self.gst_layer = GST(
num_mel=self.decoder_output_dim,
d_vector_dim=self.d_vector_dim
if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
else None,
num_heads=self.gst.gst_num_heads,
num_style_tokens=self.gst.gst_num_style_tokens,
gst_embedding_dim=self.gst.gst_embedding_dim,
@ -110,7 +107,9 @@ class Tacotron2(BaseTacotron):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
def forward( # pylint: disable=dangerous-default-value
self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
):
"""
Shapes:
text: [B, T_in]
@ -130,11 +129,10 @@ class Tacotron2(BaseTacotron):
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
)
if self.num_speakers > 1:
if not self.use_d_vectors:
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
if self.use_speaker_embedding or self.use_d_vector_file:
if not self.use_d_vector_file:
# B x 1 x speaker_embed_dim
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
else:
@ -186,8 +184,9 @@ class Tacotron2(BaseTacotron):
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
if self.num_speakers > 1:
if not self.use_d_vectors:
if not self.use_d_vector_file:
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
# reshape embedded_speakers
if embedded_speakers.ndim == 1:

767
TTS/tts/models/vits.py Normal file
View File

@ -0,0 +1,767 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
"""Segment each sample in a batch based on the provided segment indices"""
segments = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
index_start = segment_indices[i]
index_end = index_start + segment_size
segments[i] = x[i, :, index_start:index_end]
return segments
def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
"""Create random segments based on the input lengths."""
B, _, T = x.size()
if x_lengths is None:
x_lengths = T
max_idxs = x_lengths - segment_size + 1
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
ret = segment(x, ids_str, segment_size)
return ret, ids_str
@dataclass
class VitsArgs(Coqpit):
"""VITS model arguments.
Args:
num_chars (int):
Number of characters in the vocabulary. Defaults to 100.
out_channels (int):
Number of output channels. Defaults to 513.
spec_segment_size (int):
Decoder input segment size. Defaults to 32 `(32 * hop_length = waveform length)`.
hidden_channels (int):
Number of hidden channels of the model. Defaults to 192.
hidden_channels_ffn_text_encoder (int):
Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 768.
num_heads_text_encoder (int):
Number of attention heads of the text encoder transformer. Defaults to 2.
num_layers_text_encoder (int):
Number of transformer layers in the text encoder. Defaults to 6.
kernel_size_text_encoder (int):
Kernel size of the text encoder transformer FFN layers. Defaults to 3.
dropout_p_text_encoder (float):
Dropout rate of the text encoder. Defaults to 0.1.
dropout_p_duration_predictor (float):
Dropout rate of the duration predictor. Defaults to 0.1.
kernel_size_posterior_encoder (int):
Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.
dilation_rate_posterior_encoder (int):
Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.
num_layers_posterior_encoder (int):
Number of posterior encoder's WaveNet layers. Defaults to 16.
kernel_size_flow (int):
Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
dilation_rate_flow (int):
Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
num_layers_flow (int):
Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.
resblock_type_decoder (str):
Type of the residual block in the decoder network. Defaults to "1".
resblock_kernel_sizes_decoder (List[int]):
Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
resblock_dilation_sizes_decoder (List[List[int]]):
Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
upsample_rates_decoder (List[int]):
Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these
values must be equal to the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
upsample_initial_channel_decoder (int):
Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
upsample_kernel_sizes_decoder (List[int]):
Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
use_sdp (int):
Use Stochastic Duration Predictor. Defaults to True.
noise_scale (float):
Noise scale used for the sample noise tensor in training. Defaults to 1.0.
inference_noise_scale (float):
Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
length_scale (int):
Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.
noise_scale_dp (float):
Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
inference_noise_scale_dp (float):
Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8.
max_inference_len (int):
Maximum inference length to limit the memory use. Defaults to None.
init_discriminator (bool):
Initialize the discriminator network if set True. Set False for inference. Defaults to True.
use_spectral_norm_disriminator (bool):
Use spectral normalization over weight norm in the discriminator. Defaults to False.
use_speaker_embedding (bool):
Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
num_speakers (int):
Number of speakers for the speaker embedding layer. Defaults to 0.
speakers_file (str):
Path to the speaker mapping file for the Speaker Manager. Defaults to None.
speaker_embedding_channels (int):
Number of speaker embedding channels. Defaults to 256.
use_d_vector_file (bool):
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
d_vector_dim (int):
Number of d-vector channels. Defaults to 0.
detach_dp_input (bool):
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
"""
num_chars: int = 100
out_channels: int = 513
spec_segment_size: int = 32
hidden_channels: int = 192
hidden_channels_ffn_text_encoder: int = 768
num_heads_text_encoder: int = 2
num_layers_text_encoder: int = 6
kernel_size_text_encoder: int = 3
dropout_p_text_encoder: int = 0.1
dropout_p_duration_predictor: int = 0.1
kernel_size_posterior_encoder: int = 5
dilation_rate_posterior_encoder: int = 1
num_layers_posterior_encoder: int = 16
kernel_size_flow: int = 5
dilation_rate_flow: int = 1
num_layers_flow: int = 4
resblock_type_decoder: int = "1"
resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
upsample_initial_channel_decoder: int = 512
upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
use_sdp: int = True
noise_scale: float = 1.0
inference_noise_scale: float = 0.667
length_scale: int = 1
noise_scale_dp: float = 1.0
inference_noise_scale_dp: float = 0.8
max_inference_len: int = None
init_discriminator: bool = True
use_spectral_norm_disriminator: bool = False
use_speaker_embedding: bool = False
num_speakers: int = 0
speakers_file: str = None
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
d_vector_dim: int = 0
detach_dp_input: bool = True
class Vits(BaseTTS):
"""VITS TTS model
Paper::
https://arxiv.org/pdf/2106.06103.pdf
Paper Abstract::
Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel
sampling have been proposed, but their sample quality does not match that of two-stage TTS systems.
In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than
current two-stage models. Our method adopts variational inference augmented with normalizing flows and
an adversarial training process, which improves the expressive power of generative modeling. We also propose a
stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the
uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the
natural one-to-many relationship in which a text input can be spoken in multiple ways
with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS)
on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly
available TTS systems and achieves a MOS comparable to ground truth.
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
Examples:
>>> from TTS.tts.configs import VitsConfig
>>> from TTS.tts.models.vits import Vits
>>> config = VitsConfig()
>>> model = Vits(config)
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
super().__init__()
self.END2END = True
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
if "num_chars" not in config:
_, self.config, num_chars = self.get_characters(config)
config.model_args.num_chars = num_chars
else:
self.config = config
config.model_args.num_chars = config.num_chars
args = self.config.model_args
elif isinstance(config, VitsArgs):
# loading from VitsArgs
self.config = config
args = config
else:
raise ValueError("config must be either a VitsConfig or VitsArgs")
self.args = args
self.init_multispeaker(config)
self.length_scale = args.length_scale
self.noise_scale = args.noise_scale
self.inference_noise_scale = args.inference_noise_scale
self.inference_noise_scale_dp = args.inference_noise_scale_dp
self.noise_scale_dp = args.noise_scale_dp
self.max_inference_len = args.max_inference_len
self.spec_segment_size = args.spec_segment_size
self.text_encoder = TextEncoder(
args.num_chars,
args.hidden_channels,
args.hidden_channels,
args.hidden_channels_ffn_text_encoder,
args.num_heads_text_encoder,
args.num_layers_text_encoder,
args.kernel_size_text_encoder,
args.dropout_p_text_encoder,
)
self.posterior_encoder = PosteriorEncoder(
args.out_channels,
args.hidden_channels,
args.hidden_channels,
kernel_size=args.kernel_size_posterior_encoder,
dilation_rate=args.dilation_rate_posterior_encoder,
num_layers=args.num_layers_posterior_encoder,
cond_channels=self.embedded_speaker_dim,
)
self.flow = ResidualCouplingBlocks(
args.hidden_channels,
args.hidden_channels,
kernel_size=args.kernel_size_flow,
dilation_rate=args.dilation_rate_flow,
num_layers=args.num_layers_flow,
cond_channels=self.embedded_speaker_dim,
)
if args.use_sdp:
self.duration_predictor = StochasticDurationPredictor(
args.hidden_channels,
192,
3,
args.dropout_p_duration_predictor,
4,
cond_channels=self.embedded_speaker_dim,
)
else:
self.duration_predictor = DurationPredictor(
args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
)
self.waveform_decoder = HifiganGenerator(
args.hidden_channels,
1,
args.resblock_type_decoder,
args.resblock_dilation_sizes_decoder,
args.resblock_kernel_sizes_decoder,
args.upsample_kernel_sizes_decoder,
args.upsample_initial_channel_decoder,
args.upsample_rates_decoder,
inference_padding=0,
cond_channels=self.embedded_speaker_dim,
conv_pre_weight_norm=False,
conv_post_weight_norm=False,
conv_post_bias=False,
)
if args.init_discriminator:
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
If you need a different behaviour, override this function for your model.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
if hasattr(config, "model_args"):
config = config.model_args
self.embedded_speaker_dim = 0
# init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data)
if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
self.speaker_manager.num_speakers = config.num_speakers
self.num_speakers = self.speaker_manager.num_speakers
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = config.speaker_embedding_channels
self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
# init d-vector usage
if config.use_d_vector_file:
self.embedded_speaker_dim = config.d_vector_dim
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g = None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
sid = sid.unsqueeze_(0)
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
g = aux_input["d_vectors"]
return sid, g
def forward(
self,
x: torch.tensor,
x_lengths: torch.tensor,
y: torch.tensor,
y_lengths: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None},
) -> Dict:
"""Forward pass of the model.
Args:
x (torch.tensor): Batch of input character sequence IDs.
x_lengths (torch.tensor): Batch of input character sequence lengths.
y (torch.tensor): Batch of input spectrograms.
y_lengths (torch.tensor): Batch of input spectrogram lengths.
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
Returns:
Dict: model outputs keyed by the output name.
Shapes:
- x: :math:`[B, T_seq]`
- x_lengths: :math:`[B]`
- y: :math:`[B, C, T_spec]`
- y_lengths: :math:`[B]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
"""
outputs = {}
sid, g = self._set_cond_input(aux_input)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
# speaker embedding
if self.num_speakers > 1 and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
# flow layers
z_p = self.flow(z, y_mask, g=g)
# find the alignment path
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
# duration predictor
attn_durations = attn.sum(3)
if self.args.use_sdp:
nll_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
outputs["nll_duration"] = nll_duration
else:
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
log_durations = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration
# expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
o = self.waveform_decoder(z_slice, g=g)
outputs.update(
{
"model_outputs": o,
"alignments": attn.squeeze(1),
"slice_ids": slice_ids,
"z": z,
"z_p": z_p,
"m_p": m_p,
"logs_p": logs_p,
"m_q": m_q,
"logs_q": logs_q,
}
)
return outputs
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
"""
Shapes:
- x: :math:`[B, T_seq]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
"""
sid, g = self._set_cond_input(aux_input)
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
if self.num_speakers > 0 and sid:
g = self.emb_g(sid).unsqueeze(-1)
if self.args.use_sdp:
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
else:
logw = self.duration_predictor(x, x_mask, g=g)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
return outputs
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
"""TODO: create an end-point for voice conversion"""
assert self.num_speakers > 0, "num_speakers has to be larger than 0."
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
z, _, _, y_mask = self.enc_q(y, y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
Returns:
Tuple[Dict, Dict]: Model outputs and computed losses.
"""
# pylint: disable=attribute-defined-outside-init
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
if optimizer_idx == 0:
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
waveform = batch["waveform"]
# generator pass
outputs = self.forward(
text_input,
text_lengths,
linear_input.transpose(1, 2),
mel_lengths,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
)
# cache tensors for the discriminator
self.y_disc_cache = None
self.wav_seg_disc_cache = None
self.y_disc_cache = outputs["model_outputs"]
wav_seg = segment(
waveform.transpose(1, 2),
outputs["slice_ids"] * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
)
self.wav_seg_disc_cache = wav_seg
outputs["waveform_seg"] = wav_seg
# compute discriminator scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
_, outputs["feats_disc_real"] = self.disc(wav_seg)
# compute losses
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
waveform_hat=outputs["model_outputs"].float(),
waveform=wav_seg.float(),
z_p=outputs["z_p"].float(),
logs_q=outputs["logs_q"].float(),
m_p=outputs["m_p"].float(),
logs_p=outputs["logs_p"].float(),
z_len=mel_lengths,
scores_disc_fake=outputs["scores_disc_fake"],
feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"],
)
# handle the duration loss
if self.args.use_sdp:
loss_dict["nll_duration"] = outputs["nll_duration"]
loss_dict["loss"] += outputs["nll_duration"]
else:
loss_dict["loss_duration"] = outputs["loss_duration"]
loss_dict["loss"] += outputs["nll_duration"]
elif optimizer_idx == 1:
# discriminator pass
outputs = {}
# compute scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
# compute loss
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
outputs["scores_disc_real"],
outputs["scores_disc_fake"],
)
return outputs, loss_dict
def train_log(
self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
): # pylint: disable=no-self-use
"""Create visualizations and waveform examples.
For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
be projected onto Tensorboard.
Args:
ap (AudioProcessor): audio processor used at training.
batch (Dict): Model inputs used at the previous training step.
outputs (Dict): Model outputs generated at the previous training step.
Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform.
"""
y_hat = outputs[0]["model_outputs"]
y = outputs[0]["waveform_seg"]
figures = plot_results(y_hat, y, ap, name_prefix)
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
audios = {f"{name_prefix}/audio": sample_voice}
alignments = outputs[0]["alignments"]
align_img = alignments[0].data.cpu().numpy().T
figures.update(
{
"alignment": plot_alignment(align_img, output_fig=False),
}
)
return figures, audios
@torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs, "eval")
@torch.no_grad()
def test_run(self, ap) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
wav, alignment, _, _ = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
return test_figures, test_audios
def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returns 2 optimizers in a list. The first one is for the generator and the second one is for the discriminator network.
Returns:
List: optimizers.
"""
self.disc.requires_grad_(False)
gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
self.disc.requires_grad_(True)
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
)
optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
return [optimizer1, optimizer2]
def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
Returns:
List: learning rates for each optimizer.
"""
return [self.config.lr_gen, self.config.lr_disc]
def get_scheduler(self, optimizer) -> List:
"""Set the schedulers for each optimizer.
Args:
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
Returns:
List: Schedulers, one for each optimizer.
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2]
def get_criterion(self):
"""Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
`train_step()`"""
from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel
VitsDiscriminatorLoss,
VitsGeneratorLoss,
)
return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
@staticmethod
def make_symbols(config):
"""Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
whole training and inference steps."""
_pad = config.characters["pad"]
_punctuations = config.characters["punctuations"]
_letters = config.characters["characters"]
_letters_ipa = config.characters["phonemes"]
symbols = [_pad] + list(_punctuations) + list(_letters)
if config.use_phonemes:
symbols += list(_letters_ipa)
return symbols
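A toy illustration of the symbol layout `make_symbols` produces; the character set below is made up, real configs use the full alphabet and IPA phoneme inventory:

from TTS.tts.models.vits import Vits

class ToyConfig:
    use_phonemes = False
    characters = {"pad": "_", "punctuations": "!,. ", "characters": "abc", "phonemes": ""}

print(Vits.make_symbols(ToyConfig()))
# ['_', '!', ',', '.', ' ', 'a', 'b', 'c']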
@staticmethod
def get_characters(config: Coqpit):
if config.characters is not None:
symbols = Vits.make_symbols(config)
else:
from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel
parse_symbols,
phonemes,
symbols,
)
config.characters = parse_symbols()
if config.use_phonemes:
symbols = phonemes
num_chars = len(symbols) + getattr(config, "add_blank", False)
return symbols, config, num_chars
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training

View File

@ -2,6 +2,7 @@ import datetime
import importlib
import pickle
import fsspec
import numpy as np
import tensorflow as tf
@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r,
}
state.update(kwargs)
pickle.dump(state, open(output_path, "wb"))
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, "rb"))
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:

View File

@ -1,6 +1,7 @@
import datetime
import pickle
import fsspec
import tensorflow as tf
@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r,
}
state.update(kwargs)
pickle.dump(state, open(output_path, "wb"))
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, "rb"))
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:

View File

@ -1,3 +1,4 @@
import fsspec
import tensorflow as tf
@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# save the model binary if output_path is provided
with open(output_path, "wb") as f:
with fsspec.open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model

23
TTS/tts/utils/speakers.py Executable file → Normal file
View File

@ -3,6 +3,7 @@ import os
import random
from typing import Any, Dict, List, Tuple, Union
import fsspec
import numpy as np
import torch
from coqpit import Coqpit
@ -84,12 +85,12 @@ class SpeakerManager:
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with open(json_file_path) as f:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with open(json_file_path, "w") as f:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
@ -294,9 +295,10 @@ def _set_file_path(path):
Intended to band aid the different paths returned in restored and continued training."""
path_restore = os.path.join(os.path.dirname(path), "speakers.json")
path_continue = os.path.join(path, "speakers.json")
if os.path.exists(path_restore):
fs = fsspec.get_mapper(path).fs
if fs.exists(path_restore):
return path_restore
if os.path.exists(path_continue):
if fs.exists(path_continue):
return path_continue
raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")
@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
json_file = out_path
else:
json_file = _set_file_path(out_path)
with open(json_file) as f:
with fsspec.open(json_file, "r") as f:
return json.load(f)
@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present."""
if out_path is not None:
speakers_json_path = _set_file_path(out_path)
with open(speakers_json_path, "w") as f:
with fsspec.open(speakers_json_path, "w") as f:
json.dump(speaker_mapping, f, indent=4)
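These helpers now go through fsspec, so the same call works for local paths and remote object stores alike. A small sketch with an illustrative mapping and path:

import json
import fsspec

speaker_mapping = {"ljspeech": 0}                 # illustrative content
with fsspec.open("speakers.json", "w") as f:      # or e.g. "gs://my-bucket/run-1/speakers.json"
    json.dump(speaker_mapping, f, indent=4)

with fsspec.open("speakers.json", "r") as f:
    assert json.load(f) == speaker_mapping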
@ -358,10 +360,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
elif c.use_d_vector_file and c.d_vector_file:
# new speaker manager with external speaker embeddings.
speaker_manager.set_d_vectors_from_file(c.d_vector_file)
elif c.use_d_vector_file and not c.d_vector_file: # new speaker manager with speaker IDs file.
raise "use_d_vector_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
elif c.use_d_vector_file and not c.d_vector_file:
raise "use_d_vector_file is True, so you need pass a external speaker embedding file."
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
# new speaker manager with speaker IDs file.
speaker_manager.set_speaker_ids_from_file(c.speakers_file)
print(
" > Training with {} speakers: {}".format(
" > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
)
)

View File

@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
import tensorflow as tf
def text_to_seq(text, CONFIG):
def text_to_seq(text, CONFIG, custom_symbols=None):
text_cleaner = [CONFIG.text_cleaner]
# text or phonemes to sequence vector
if CONFIG.use_phonemes:
@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG):
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
use_espeak_phonemes=CONFIG.use_espeak_phonemes,
custom_symbols=custom_symbols,
),
dtype=np.int32,
)
else:
seq = np.asarray(
text_to_sequence(
text,
text_cleaner,
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
),
dtype=np.int32,
)
@ -229,13 +227,16 @@ def synthesis(
"""
# GST processing
style_mel = None
custom_symbols = None
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
if hasattr(model, "make_symbols"):
custom_symbols = model.make_symbols(CONFIG)
# preprocess the given text
text_inputs = text_to_seq(text, CONFIG)
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
# pass tensors to backend
if backend == "torch":
if speaker_id is not None:
@ -274,15 +275,18 @@ def synthesis(
# convert outputs to numpy
# plot results
wav = None
if use_griffin_lim:
wav = inv_spectrogram(model_outputs, ap, CONFIG)
# trim silence
if do_trim_silence:
wav = trim_silence(wav, ap)
if hasattr(model, "END2END") and model.END2END:
wav = model_outputs.squeeze(0)
else:
if use_griffin_lim:
wav = inv_spectrogram(model_outputs, ap, CONFIG)
# trim silence
if do_trim_silence:
wav = trim_silence(wav, ap)
return_dict = {
"wav": wav,
"alignments": alignments,
"model_outputs": model_outputs,
"text_inputs": text_inputs,
"outputs": outputs,
}
return return_dict
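A hedged sketch of calling `synthesis` directly, mirroring the `test_run` methods above; `model`, `config` and `ap` are assumed to be an already-initialized 🐸TTS model, its config, and a matching `AudioProcessor`:

use_cuda = "cuda" in str(next(model.parameters()).device)
outputs = synthesis(
    model,
    "This is a test sentence.",
    config,
    use_cuda,
    ap,
    speaker_id=None,
    d_vector=None,
    style_wav=None,
    enable_eos_bos_chars=config.enable_eos_bos_chars,
    use_griffin_lim=True,
    do_trim_silence=False,
)
wav = outputs["wav"]                 # Griffin-Lim audio for spectrogram models,
                                     # direct decoder output for end-to-end models (END2END)
alignments = outputs["alignments"]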

View File

@ -2,10 +2,9 @@
# adapted from https://github.com/keithito/tacotron
import re
import unicodedata
from typing import Dict, List
import gruut
from packaging import version
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_symbols = symbols
_phonemes = phonemes
# Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
@ -81,7 +81,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
# Fix a few phonemes
ph = ph.translate(GRUUT_TRANS_TABLE)
print(" > Phonemes: {}".format(ph))
return ph
raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@ -106,13 +105,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
def phoneme_to_sequence(
text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False
):
text: str,
cleaner_names: List[str],
language: str,
enable_eos_bos: bool = False,
custom_symbols: List[str] = None,
tp: Dict = None,
add_blank: bool = False,
use_espeak_phonemes: bool = False,
) -> List[int]:
"""Converts a string of phonemes to a sequence of IDs.
If `custom_symbols` is provided, it will override the default symbols.
Args:
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
language (str): text language key for phonemization.
enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
use_espeak_phonemes (bool): use espeak-based lexicons to convert phonemes to sequence IDs.
Returns:
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _phonemes_to_id, _phonemes
if tp:
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
sequence = []
clean_text = _clean_text(text, cleaner_names)
@ -127,20 +151,22 @@ def phoneme_to_sequence(
sequence = pad_with_eos_bos(sequence, tp=tp)
if add_blank:
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
return sequence
def sequence_to_phoneme(sequence, tp=None, add_blank=False):
def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None):
# pylint: disable=global-statement
"""Converts a sequence of IDs back to a string"""
global _id_to_phonemes, _phonemes
if add_blank:
sequence = list(filter(lambda x: x != len(_phonemes), sequence))
result = ""
if tp:
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp)
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
for symbol_id in sequence:
if symbol_id in _id_to_phonemes:
@ -149,27 +175,32 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
return result.replace("}{", " ")
def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
def text_to_sequence(
text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
) -> List[int]:
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
If `custom_symbols` is provided, it will override the default symbols.
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
tp: dictionary of character parameters to use a custom character set.
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
Returns:
List of integers corresponding to the symbols in the text
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _symbol_to_id, _symbols
if tp:
if custom_symbols is not None:
_symbols = custom_symbols
elif tp:
_symbols, _ = make_symbols(**tp)
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while text:
m = _CURLY_RE.match(text)
@ -185,16 +216,18 @@ def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
return sequence
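A short sketch of the new `custom_symbols` path: a model that defines `make_symbols` (e.g. VITS above) can hand its own symbol table to the converters, overriding the module-level defaults. `model` and `config` are assumed here, with the usual config attribute names:

custom_symbols = model.make_symbols(config)          # e.g. Vits.make_symbols
seq = text_to_sequence(
    "Hello world.",
    [config.text_cleaner],
    custom_symbols=custom_symbols,
    add_blank=config.add_blank,
)
text_back = sequence_to_text(seq, custom_symbols=custom_symbols, add_blank=config.add_blank)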
def sequence_to_text(sequence, tp=None, add_blank=False):
def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
"""Converts a sequence of IDs back to a string"""
# pylint: disable=global-statement
global _id_to_symbol, _symbols
if add_blank:
sequence = list(filter(lambda x: x != len(_symbols), sequence))
if tp:
if custom_symbols is not None:
_symbols = custom_symbols
elif tp:
_symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
result = ""
for symbol_id in sequence:

View File

@ -28,10 +28,10 @@ def make_symbols(
sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
) # this is to keep previous models compatible.
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ["@" + s for s in _phonemes_sorted]
# _arpabet = ["@" + s for s in _phonemes_sorted]
# Export all symbols:
_phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
_symbols += _arpabet
# _symbols += _arpabet
return _symbols, _phonemes

View File

@ -14,7 +14,10 @@ from TTS.tts.utils.data import StandardScaler
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""TODO: Merge this with audio.py"""
"""Some of the audio processing funtions using Torch for faster batch processing.
TODO: Merge this with audio.py
"""
def __init__(
self,
@ -28,6 +31,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
mel_fmax=None,
n_mels=80,
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
):
super().__init__()
self.n_fft = n_fft
@ -39,6 +44,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
@ -79,6 +86,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S
def _build_mel_basis(self):
@ -87,6 +96,14 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
)
self.mel_basis = torch.from_numpy(mel_basis).float()
@staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
@staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain
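A quick sanity check (not part of the diff) that the two static helpers above invert each other for magnitudes above the 1e-5 clamp, up to the `spec_gain` factor; the formulas are copied verbatim:

import torch

def amp_to_db(x, spec_gain=1.0):        # same formula as TorchSTFT._amp_to_db
    return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

def db_to_amp(x, spec_gain=1.0):        # same formula as TorchSTFT._db_to_amp
    return torch.exp(x) / spec_gain

S = torch.rand(80, 100) + 0.1           # fake magnitude spectrogram, safely above the clamp
assert torch.allclose(S, db_to_amp(amp_to_db(S)), atol=1e-6)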
# pylint: disable=too-many-public-methods
class AudioProcessor(object):
@ -97,33 +114,93 @@ class AudioProcessor(object):
of the class with the model config. They are not meaningful for all the arguments.
Args:
sample_rate (int, optional): target audio sampling rate. Defaults to None.
resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
log_func (int, optional): log exponent used for converting spectrogram aplitude to DB.
min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None.
frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None.
hop_length (int, optional): number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
win_length (int, optional): STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
ref_level_db (int, optional): reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
fft_size (int, optional): FFT window size for STFT. Defaults to 1024.
power (int, optional): Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
preemphasis (float, optional): Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
signal_norm (bool, optional): enable/disable signal normalization. Defaults to None.
symmetric_norm (bool, optional): enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
max_norm (float, optional): ```k``` defining the normalization range. Defaults to None.
mel_fmin (int, optional): minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional): maximum filter frequency for computing melspectrograms.. Defaults to None.
spec_gain (int, optional): gain applied when converting amplitude to DB. Defaults to 20.
stft_pad_mode (str, optional): Padding mode for STFT. Defaults to 'reflect'.
clip_norm (bool, optional): enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
griffin_lim_iters (int, optional): Number of GriffinLim iterations. Defaults to None.
do_trim_silence (bool, optional): enable/disable silence trimming when loading the audio signal. Defaults to False.
trim_db (int, optional): DB threshold used for silence trimming. Defaults to 60.
do_sound_norm (bool, optional): enable/disable signal normalization. Defaults to False.
stats_path (str, optional): Path to the computed stats file. Defaults to None.
verbose (bool, optional): enable/disable logging. Defaults to True.
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
resample (bool, optional):
enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
num_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
log_func (int, optional):
log exponent used for converting spectrogram amplitude to DB.
min_level_db (int, optional):
minimum db threshold for the computed melspectrograms. Defaults to None.
frame_shift_ms (int, optional):
milliseconds of frames between STFT columns. Defaults to None.
frame_length_ms (int, optional):
milliseconds of STFT window length. Defaults to None.
hop_length (int, optional):
number of audio samples between adjacent STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
win_length (int, optional):
STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
ref_level_db (int, optional):
reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
fft_size (int, optional):
FFT window size for STFT. Defaults to 1024.
power (int, optional):
Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
preemphasis (float, optional):
Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
signal_norm (bool, optional):
enable/disable signal normalization. Defaults to None.
symmetric_norm (bool, optional):
enable/disable symmetric normalization. If set True, normalization is performed in the range [-k, k], else [0, k]. Defaults to None.
max_norm (float, optional):
```k``` defining the normalization range. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
spec_gain (int, optional):
gain applied when converting amplitude to DB. Defaults to 20.
stft_pad_mode (str, optional):
Padding mode for STFT. Defaults to 'reflect'.
clip_norm (bool, optional):
enable/disable clipping of the out-of-range values in the normalized audio signal. Defaults to True.
griffin_lim_iters (int, optional):
Number of GriffinLim iterations. Defaults to None.
do_trim_silence (bool, optional):
enable/disable silence trimming when loading the audio signal. Defaults to False.
trim_db (int, optional):
DB threshold used for silence trimming. Defaults to 60.
do_sound_norm (bool, optional):
enable/disable signal normalization. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
stats_path (str, optional):
Path to the computed stats file. Defaults to None.
verbose (bool, optional):
enable/disable logging. Defaults to True.
"""
def __init__(
@ -153,6 +230,8 @@ class AudioProcessor(object):
do_trim_silence=False,
trim_db=60,
do_sound_norm=False,
do_amp_to_db_linear=True,
do_amp_to_db_mel=True,
stats_path=None,
verbose=True,
**_,
@ -181,6 +260,8 @@ class AudioProcessor(object):
self.do_trim_silence = do_trim_silence
self.trim_db = trim_db
self.do_sound_norm = do_sound_norm
self.do_amp_to_db_linear = do_amp_to_db_linear
self.do_amp_to_db_mel = do_amp_to_db_mel
self.stats_path = stats_path
# setup exp_func for db to amp conversion
if log_func == "np.log":
@ -381,7 +462,6 @@ class AudioProcessor(object):
Returns:
np.ndarray: Decibels spectrogram.
"""
return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
# pylint: disable=no-self-use
@ -448,7 +528,10 @@ class AudioProcessor(object):
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
S = self._amp_to_db(np.abs(D))
if self.do_amp_to_db_linear:
S = self._amp_to_db(np.abs(D))
else:
S = np.abs(D)
return self.normalize(S).astype(np.float32)
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
@ -457,7 +540,10 @@ class AudioProcessor(object):
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
if self.do_amp_to_db_mel:
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
else:
S = self._linear_to_mel(np.abs(D))
return self.normalize(S).astype(np.float32)
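A hedged sketch of the two new flags: with `do_amp_to_db_mel=False` the mel spectrogram stays in linear amplitude (useful for models that expect raw magnitudes), while linear spectrograms are still converted to dB. The audio parameter values below are illustrative and assume the remaining `AudioProcessor` defaults are acceptable for your setup:

import numpy as np
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(
    sample_rate=22050,
    num_mels=80,
    fft_size=1024,
    hop_length=256,
    win_length=1024,
    mel_fmin=0,
    mel_fmax=8000,
    signal_norm=False,
    do_amp_to_db_linear=True,    # linear spectrograms converted to dB as before
    do_amp_to_db_mel=False,      # mel spectrograms kept in amplitude
)
wav = np.random.uniform(-1, 1, 22050).astype(np.float32)   # one second of fake audio
mel = ap.melspectrogram(wav)                                # [num_mels, T], linear amplitude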
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:

View File

@ -1,15 +1,14 @@
# -*- coding: utf-8 -*-
import datetime
import glob
import importlib
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Dict
import fsspec
import torch
@ -58,23 +57,22 @@ def get_commit_hash():
return commit
def create_experiment_folder(root_path, model_name):
"""Create a folder with the current date and time"""
def get_experiment_folder_path(root_path, model_name):
"""Get an experiment folder path with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder))
return output_folder
def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
fs = fsspec.get_mapper(experiment_path).fs
checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
if not checkpoint_files:
if os.path.exists(experiment_path):
shutil.rmtree(experiment_path, ignore_errors=True)
if fs.exists(experiment_path):
fs.rm(experiment_path, recursive=True)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))

View File

@ -1,9 +1,11 @@
import datetime
import glob
import json
import os
import pickle as pickle_tts
from shutil import copyfile
import shutil
from typing import Any
import fsspec
import torch
from coqpit import Coqpit
@ -24,7 +26,7 @@ class AttrDict(dict):
self.__dict__ = self
def copy_model_files(config, out_path, new_fields):
def copy_model_files(config: Coqpit, out_path, new_fields):
"""Copy config.json and other model files to training folder and add
new fields.
@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
copy_config_path = os.path.join(out_path, "config.json")
# add extra information fields
config.update(new_fields, allow_new=True)
config.save_json(copy_config_path)
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
json.dump(config.to_dict(), f, indent=4)
# copy model stats file if available
if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
if not os.path.exists(copy_stats_path):
copyfile(
config.audio.stats_path,
copy_stats_path,
)
filesystem = fsspec.get_mapper(copy_stats_path).fs
if not filesystem.exists(copy_stats_path):
with fsspec.open(config.audio.stats_path, "rb") as source_file:
with fsspec.open(copy_stats_path, "wb") as target_file:
shutil.copyfileobj(source_file, target_file)
def load_fsspec(path: str, **kwargs) -> Any:
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
Args:
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.load.
Returns:
Object stored in path.
"""
with fsspec.open(path, "rb") as f:
return torch.load(f, **kwargs)
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
try:
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
model.load_state_dict(state["model"])
if use_cuda:
model.cuda()
@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli
return model, state
def save_fsspec(state: Any, path: str, **kwargs):
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
Args:
state: State object to save
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.save.
"""
with fsspec.open(path, "wb") as f:
torch.save(state, f, **kwargs)
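# A short usage sketch for the two helpers above. Any fsspec-supported URL works,
# so checkpoints can be written to and read from object storage as well as the
# local disk; the paths below are placeholders.
import torch

dummy_state = {"step": 0, "model": {}}
save_fsspec(dummy_state, "/tmp/checkpoint_0.pth.tar")
restored = load_fsspec("/tmp/checkpoint_0.pth.tar", map_location=torch.device("cpu"))
# e.g. save_fsspec(dummy_state, "s3://my-bucket/run-1/checkpoint_0.pth.tar") also
# works once the matching fsspec backend (s3fs, gcsfs, ...) is installed.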
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"):
model_state = model.module.state_dict()
@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
torch.save(state, output_path)
save_fsspec(state, output_path)
def save_checkpoint(
@ -147,18 +178,16 @@ def save_best_model(
model_loss=current_loss,
**kwargs,
)
fs = fsspec.get_mapper(out_path).fs
# only delete previous if current is saved successfully
if not keep_all_best or (current_step < keep_after):
model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
for model_name in model_names:
if os.path.basename(model_name) == best_model_name:
continue
os.remove(model_name)
# create symlink to best model for convenience
link_name = "best_model.pth.tar"
link_path = os.path.join(out_path, link_name)
if os.path.islink(link_path) or os.path.isfile(link_path):
os.remove(link_path)
os.symlink(best_model_name, os.path.join(out_path, link_name))
if os.path.basename(model_name) != best_model_name:
fs.rm(model_name)
# create a shortcut which always points to the currently best model
shortcut_name = "best_model.pth.tar"
shortcut_path = os.path.join(out_path, shortcut_name)
fs.copy(checkpoint_path, shortcut_path)
best_loss = current_loss
return best_loss

View File

@ -1,2 +1,24 @@
from TTS.utils.logging.console_logger import ConsoleLogger
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from TTS.utils.logging.wandb_logger import WandbLogger
def init_logger(config):
if config.dashboard_logger == "tensorboard":
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)
elif config.dashboard_logger == "wandb":
project_name = config.model
if config.project_name:
project_name = config.project_name
dashboard_logger = WandbLogger(
project=project_name,
name=config.run_name,
config=config,
entity=config.wandb_entity,
)
dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
return dashboard_logger
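# A minimal sketch of how the logger factory above can be driven. A stand-in config
# object is used here; in practice these fields come from the training config.
import json
from types import SimpleNamespace

class _DummyConfig(SimpleNamespace):
    def to_json(self):
        return json.dumps(vars(self))

config = _DummyConfig(
    dashboard_logger="tensorboard",  # or "wandb"
    output_log_path="/tmp/run/logs",
    model="glow_tts",
    project_name=None,
    run_name="my_run",
    wandb_entity=None,
)
dashboard_logger = init_logger(config)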

View File

@ -38,7 +38,7 @@ class ConsoleLogger:
def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict):
def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_dict):
indent = " | > "
print()
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
@ -50,13 +50,6 @@ class ConsoleLogger:
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
for idx, (key, value) in enumerate(log_dict.items()):
if isinstance(value, list):
log_text += f"{indent}{key}: {value[0]:.{value[1]}f}"
else:
log_text += f"{indent}{key}: {value}"
if idx < len(log_dict) - 1:
log_text += "\n"
print(log_text, flush=True)
# pylint: disable=unused-argument

View File

@ -7,10 +7,8 @@ class TensorboardLogger(object):
def __init__(self, log_dir, model_name):
self.model_name = model_name
self.writer = SummaryWriter(log_dir)
self.train_stats = {}
self.eval_stats = {}
def tb_model_weights(self, model, step):
def model_weights(self, model, step):
layer_num = 1
for name, param in model.named_parameters():
if param.numel() == 1:
@ -41,32 +39,41 @@ class TensorboardLogger(object):
except RuntimeError:
traceback.print_exc()
def tb_train_step_stats(self, step, stats):
def train_step_stats(self, step, stats):
self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)
def tb_train_epoch_stats(self, step, stats):
def train_epoch_stats(self, step, stats):
self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step)
def tb_train_figures(self, step, figures):
def train_figures(self, step, figures):
self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step)
def tb_train_audios(self, step, audios, sample_rate):
def train_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate)
def tb_eval_stats(self, step, stats):
def eval_stats(self, step, stats):
self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step)
def tb_eval_figures(self, step, figures):
def eval_figures(self, step, figures):
self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step)
def tb_eval_audios(self, step, audios, sample_rate):
def eval_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate)
def tb_test_audios(self, step, audios, sample_rate):
def test_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate)
def tb_test_figures(self, step, figures):
def test_figures(self, step, figures):
self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step)
def tb_add_text(self, title, text, step):
def add_text(self, title, text, step):
self.writer.add_text(title, text, step)
def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613, R0201
yield
def flush(self):
self.writer.flush()
def finish(self):
self.writer.close()

View File

@ -0,0 +1,111 @@
# pylint: disable=W0613
import traceback
from pathlib import Path
try:
import wandb
from wandb import finish, init # pylint: disable=W0611
except ImportError:
wandb = None
class WandbLogger:
def __init__(self, **kwargs):
if not wandb:
raise Exception("install wandb using `pip install wandb` to use WandbLogger")
self.run = None
self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
self.model_name = self.run.config.model
self.log_dict = {}
def model_weights(self, model):
layer_num = 1
for name, param in model.named_parameters():
if param.numel() == 1:
self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()})
else:
self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()})
self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
self.log_dict["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param)
self.log_dict["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad)
layer_num += 1
def dict_to_scalar(self, scope_name, stats):
for key, value in stats.items():
self.log_dict["{}/{}".format(scope_name, key)] = value
def dict_to_figure(self, scope_name, figures):
for key, value in figures.items():
self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value)
def dict_to_audios(self, scope_name, audios, sample_rate):
for key, value in audios.items():
if value.dtype == "float16":
value = value.astype("float32")
try:
self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
except RuntimeError:
traceback.print_exc()
def log(self, log_dict, prefix="", flush=False):
for key, value in log_dict.items():
self.log_dict[prefix + key] = value
if flush: # for cases where you don't want to accumulate data
self.flush()
def train_step_stats(self, step, stats):
self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats)
def train_epoch_stats(self, step, stats):
self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats)
def train_figures(self, step, figures):
self.dict_to_figure(f"{self.model_name}_TrainFigures", figures)
def train_audios(self, step, audios, sample_rate):
self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate)
def eval_stats(self, step, stats):
self.dict_to_scalar(f"{self.model_name}_EvalStats", stats)
def eval_figures(self, step, figures):
self.dict_to_figure(f"{self.model_name}_EvalFigures", figures)
def eval_audios(self, step, audios, sample_rate):
self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate)
def test_audios(self, step, audios, sample_rate):
self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate)
def test_figures(self, step, figures):
self.dict_to_figure(f"{self.model_name}_TestFigures", figures)
def add_text(self, title, text, step):
pass
def flush(self):
if self.run:
wandb.log(self.log_dict)
self.log_dict = {}
def finish(self):
if self.run:
self.run.finish()
def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
if not self.run:
return
name = "_".join([self.run.id, name])
artifact = wandb.Artifact(name, type=artifact_type)
data_path = Path(file_or_dir)
if data_path.is_dir():
artifact.add_dir(str(data_path))
elif data_path.is_file():
artifact.add_file(str(data_path))
self.run.log_artifact(artifact, aliases=aliases)
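# A hedged usage sketch for the logger above. The project/run names are placeholders
# and `mode="offline"` is a standard wandb option used here so the example does not
# need a live wandb account.
logger = WandbLogger(
    project="tts-experiments",
    name="glow_tts_run",
    config={"model": "glow_tts"},  # the constructor reads `config.model` for scoping
    mode="offline",
)
logger.train_step_stats(step=0, stats={"loss_gen": 1.23})
logger.flush()   # pushes the accumulated metrics for this step
logger.finish()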

View File

@ -64,6 +64,7 @@ class ModelManager(object):
def list_models(self):
print(" Name format: type/language/dataset/model")
models_name_list = []
model_count = 1
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
@ -71,10 +72,11 @@ class ModelManager(object):
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" >: {model_type}/{lang}/{dataset}/{model}")
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
return models_name_list
def download_model(self, model_name):

View File

@ -12,7 +12,6 @@ from TTS.tts.utils.speakers import SpeakerManager
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, trim_silence
from TTS.tts.utils.text import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
@ -103,6 +102,34 @@ class Synthesizer(object):
self.num_speakers = self.speaker_manager.num_speakers
self.d_vector_dim = self.speaker_manager.d_vector_dim
def _set_tts_speaker_file(self):
"""Set the TTS speaker file used by a multi-speaker model."""
# setup if multi-speaker settings are in the global model config
if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
if self.tts_config.use_d_vector_file:
self.tts_speakers_file = (
self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
)
self.tts_config["d_vector_file"] = self.tts_speakers_file
else:
self.tts_speakers_file = (
self.tts_speakers_file if self.tts_speakers_file else self.tts_config["speakers_file"]
)
# setup if multi-speaker settings are in the model args config
if (
self.tts_speakers_file is None
and hasattr(self.tts_config, "model_args")
and hasattr(self.tts_config.model_args, "use_speaker_embedding")
and self.tts_config.model_args.use_speaker_embedding
):
_args = self.tts_config.model_args
if _args.use_d_vector_file:
self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["d_vector_file"]
_args["d_vector_file"] = self.tts_speakers_file
else:
self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["speakers_file"]
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
"""Load the TTS model.
@ -113,29 +140,15 @@ class Synthesizer(object):
"""
# pylint: disable=global-statement
global symbols, phonemes
self.tts_config = load_config(tts_config_path)
self.use_phonemes = self.tts_config.use_phonemes
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
if self.tts_config.has("characters") and self.tts_config.characters:
symbols, phonemes = make_symbols(**self.tts_config.characters)
if self.use_phonemes:
self.input_size = len(phonemes)
else:
self.input_size = len(symbols)
if self.tts_config.use_speaker_embedding is True:
self.tts_speakers_file = (
self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
)
self.tts_config["d_vector_file"] = self.tts_speakers_file
self.tts_model = setup_tts_model(config=self.tts_config)
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda:
self.tts_model.cuda()
self._set_tts_speaker_file()
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
"""Load the vocoder model.
@ -187,15 +200,22 @@ class Synthesizer(object):
"""
start_time = time.time()
wavs = []
speaker_embedding = None
sens = self.split_into_sentences(text)
print(" > Text splitted to sentences.")
print(sens)
# handle multi-speaker
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file:
# get the speaker embedding from the saved d_vectors.
if speaker_idx and isinstance(speaker_idx, str):
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
if self.tts_config.use_d_vector_file:
# get the speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
else:
# get speaker idx from the speaker name
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
elif not speaker_idx and not speaker_wav:
raise ValueError(
" [!] Look like you use a multi-speaker model. "
@ -224,14 +244,14 @@ class Synthesizer(object):
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
ap=self.ap,
speaker_id=None,
speaker_id=speaker_id,
style_wav=style_wav,
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,
)
waveform = outputs["wav"]
mel_postnet_spec = outputs["model_outputs"]
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy()
if not use_gl:
# denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T

View File

@ -1,5 +1,5 @@
import importlib
from typing import Dict
from typing import Dict, List
import torch
@ -48,7 +48,7 @@ def get_scheduler(
def get_optimizer(
optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
) -> torch.optim.Optimizer:
"""Find, initialize and return a optimizer.
@ -66,4 +66,6 @@ def get_optimizer(
optimizer = getattr(module, "RAdam")
else:
optimizer = getattr(torch.optim, optimizer_name)
return optimizer(model.parameters(), lr=lr, **optimizer_params)
if model is not None:
parameters = model.parameters()
return optimizer(parameters, lr=lr, **optimizer_params)
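# A brief sketch of the two ways to call the updated helper above: either hand it a
# model, or pass an explicit parameter list (useful e.g. when a GAN trains generator
# and discriminator with separate optimizers). The layer below is a placeholder.
import torch

net = torch.nn.Linear(10, 10)
opt_from_model = get_optimizer("Adam", {"betas": (0.9, 0.998)}, lr=1e-3, model=net)
opt_from_params = get_optimizer("AdamW", {"weight_decay": 0.01}, lr=1e-3, parameters=net.parameters())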

View File

@ -24,8 +24,10 @@ def setup_model(config: Coqpit):
elif config.model.lower() == "wavegrad":
MyModel = getattr(MyModel, "Wavegrad")
else:
MyModel = getattr(MyModel, to_camel(config.model))
raise ValueError(f"Model {config.model} not exist!")
try:
MyModel = getattr(MyModel, to_camel(config.model))
except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e
model = MyModel(config)
return model

View File

@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
@ -222,7 +223,7 @@ class GAN(BaseVocoder):
checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If True, load the model for inference. Defaults to False.
"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# band-aid for older than v0.0.15 GAN models
if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval)

View File

@ -33,10 +33,10 @@ class DiscriminatorP(torch.nn.Module):
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
]
)
@ -81,15 +81,15 @@ class MultiPeriodDiscriminator(torch.nn.Module):
Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
"""
def __init__(self):
def __init__(self, use_spectral_norm=False):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2),
DiscriminatorP(3),
DiscriminatorP(5),
DiscriminatorP(7),
DiscriminatorP(11),
DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
]
)
@ -99,7 +99,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
x (Tensor): input waveform.
Returns:
[List[Tensor]]: list of scores from each discriminator.
[List[Tensor]]: list of scores from each discriminator.
[List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.
Shapes:

View File

@ -5,6 +5,8 @@ import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1
@ -168,6 +170,10 @@ class HifiganGenerator(torch.nn.Module):
upsample_initial_channel,
upsample_factors,
inference_padding=5,
cond_channels=0,
conv_pre_weight_norm=True,
conv_post_weight_norm=True,
conv_post_bias=True,
):
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
@ -216,12 +222,21 @@ class HifiganGenerator(torch.nn.Module):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
def forward(self, x):
if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre)
if not conv_post_weight_norm:
remove_weight_norm(self.conv_post)
def forward(self, x, g=None):
"""
Args:
x (Tensor): conditioning input tensor.
x (Tensor): feature input tensor.
g (Tensor): global conditioning input tensor.
Returns:
Tensor: output waveform.
@ -231,6 +246,8 @@ class HifiganGenerator(torch.nn.Module):
Tensor: [B, 1, T]
"""
o = self.conv_pre(x)
if hasattr(self, "cond_layer"):
o = o + self.cond_layer(g)
for i in range(self.num_upsamples):
o = F.leaky_relu(o, LRELU_SLOPE)
o = self.ups[i](o)
@ -275,7 +292,7 @@ class HifiganGenerator(torch.nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@ -2,6 +2,7 @@ import torch
from torch import nn
from torch.nn.utils import weight_norm
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack
@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@ -3,6 +3,7 @@ import math
import numpy as np
import torch
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@ -1,3 +1,5 @@
from typing import List
import numpy as np
import torch
import torch.nn.functional as F
@ -10,18 +12,35 @@ LRELU_SLOPE = 0.1
class UnivnetGenerator(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
cond_channels,
upsample_factors,
lvc_layers_each_block,
lvc_kernel_size,
kpnet_hidden_channels,
kpnet_conv_size,
dropout,
in_channels: int,
out_channels: int,
hidden_channels: int,
cond_channels: int,
upsample_factors: List[int],
lvc_layers_each_block: int,
lvc_kernel_size: int,
kpnet_hidden_channels: int,
kpnet_conv_size: int,
dropout: float,
use_weight_norm=True,
):
"""Univnet Generator network.
Paper: https://arxiv.org/pdf/2106.07889.pdf
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of channels of the output tensor.
hidden_channels (int): Number of hidden network channels.
cond_channels (int): Number of channels of the conditioning tensors.
upsample_factors (List[int]): List of upsample factors for the upsampling layers.
lvc_layers_each_block (int): Number of LVC layers in each block.
lvc_kernel_size (int): Kernel size of the LVC layers.
kpnet_hidden_channels (int): Number of hidden channels in the kernel predictor network.
kpnet_conv_size (int): Number of convolution channels in the kernel predictor network.
dropout (float): Dropout rate.
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
"""
super().__init__()
self.in_channels = in_channels

View File

@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder
@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@ -1,6 +1,7 @@
import datetime
import pickle
import fsspec
import tensorflow as tf
@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
pickle.dump(state, open(output_path, "wb"))
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
"""Load TF Vocoder model"""
checkpoint = pickle.load(open(checkpoint_path, "rb"))
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:

View File

@ -1,3 +1,4 @@
import fsspec
import tensorflow as tf
@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=Tru
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# same model binary if outputpath is provided
with open(output_path, "wb") as f:
with fsspec.open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model

View File

@ -1,8 +1,11 @@
from typing import Dict
import numpy as np
import torch
from matplotlib import pyplot as plt
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
def interpolate_vocoder_input(scale_factor, spec):
@ -26,12 +29,24 @@ def interpolate_vocoder_input(scale_factor, spec):
return spec
def plot_results(y_hat, y, ap, name_prefix):
"""Plot vocoder model results"""
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
"""Plot the predicted and the real waveform and their spectrograms.
Args:
y_hat (torch.tensor): Predicted waveform.
y (torch.tensor): Real waveform.
ap (AudioProcessor): Audio processor used to process the waveform.
name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
Returns:
Dict: output figures keyed by the name of the figures.
""" """Plot vocoder model results"""
if name_prefix is None:
name_prefix = ""
# select an instance from batch
y_hat = y_hat[0].squeeze(0).detach().cpu().numpy()
y = y[0].squeeze(0).detach().cpu().numpy()
y_hat = y_hat[0].squeeze().detach().cpu().numpy()
y = y[0].squeeze().detach().cpu().numpy()
spec_fake = ap.melspectrogram(y_hat).T
spec_real = ap.melspectrogram(y).T

View File

@ -2,4 +2,5 @@ furo
myst-parser == 0.15.1
sphinx == 4.0.2
sphinx_inline_tabs
sphinx_copybutton
sphinx_copybutton
linkify-it-py

View File

@ -68,6 +68,8 @@ extensions = [
"sphinx_inline_tabs",
]
myst_enable_extensions = ['linkify',]
# 'sphinxcontrib.katex',
# 'sphinx.ext.autosectionlabel',

View File

@ -44,6 +44,7 @@
:caption: `tts` Models
models/glow_tts.md
models/vits.md
.. toctree::
:maxdepth: 2

View File

@ -0,0 +1,22 @@
# Glow TTS
Glow TTS is a normalizing flow model for text-to-speech. It is built on the generic Glow model previously
used in computer vision and vocoder models. It uses "monotonic alignment search" (MAS) to find the text-to-speech alignment
and uses the output to train a separate duration predictor network for faster inference run-time.
## Important resources & papers
- GlowTTS: https://arxiv.org/abs/2005.11129
- Glow (Generative Flow with invertible 1x1 Convolutions): https://arxiv.org/abs/1807.03039
- Normalizing Flows: https://blog.evjang.com/2018/01/nf1.html
## GlowTTS Config
```{eval-rst}
.. autoclass:: TTS.tts.configs.glow_tts_config.GlowTTSConfig
:members:
```
## GlowTTS Model
```{eval-rst}
.. autoclass:: TTS.tts.models.glow_tts.GlowTTS
:members:
```
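## Example training recipe
A condensed, hedged sketch based on the LJSpeech GlowTTS recipe shipped with this release; the dataset path and the hyper-parameter values are placeholders to adapt to your setup.
```python
import os

from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
config = GlowTTSConfig(
    batch_size=32,
    eval_batch_size=16,
    run_eval=True,
    print_step=25,
    print_eval=True,
    mixed_precision=False,
    output_path=output_path,
    datasets=[dataset_config],
)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()
```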

View File

@ -0,0 +1,33 @@
# VITS
VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech) is an
end-to-end (encoder -> vocoder together) TTS model that takes advantage of SOTA DL techniques like GANs, VAEs,
and Normalizing Flows. It does not require external alignment annotations; it learns the text-to-audio alignment
with MAS, as explained in the paper. The model architecture combines the GlowTTS encoder and the HiFiGAN vocoder.
It is a feed-forward model with a roughly x67.12 real-time factor on a GPU.
## Important resources & papers
- VITS: https://arxiv.org/pdf/2106.06103.pdf
- Neural Spline Flows: https://arxiv.org/abs/1906.04032
- Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf
- Generative Adversarial Networks: https://arxiv.org/abs/1406.2661
- HiFiGAN: https://arxiv.org/abs/2010.05646
- Normalizing Flows: https://blog.evjang.com/2018/01/nf1.html
## VitsConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.vits_config.VitsConfig
:members:
```
## VitsArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.vits.VitsArgs
:members:
```
## Vits Model
```{eval-rst}
.. autoclass:: TTS.tts.models.vits.Vits
:members:
```
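## Example training recipe
A condensed, hedged sketch based on the LJSpeech VITS recipe added in this release; the dataset path and the hyper-parameter values are placeholders.
```python
import os

from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, VitsConfig

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    mel_fmin=0,
    mel_fmax=None,
    signal_norm=False,
    do_amp_to_db_linear=False,
)
config = VitsConfig(
    audio=audio_config,
    run_name="vits_ljspeech",
    batch_size=48,
    eval_batch_size=16,
    epochs=1000,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    print_step=25,
    mixed_precision=True,
    output_path=output_path,
    datasets=[dataset_config],
)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
trainer.fit()
```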

View File

@ -85,6 +85,7 @@ We still support running training from CLI like in the old days. The same traini
```json
{
"run_name": "my_run",
"model": "glow_tts",
"batch_size": 32,
"eval_batch_size": 16,

View File

@ -25,9 +25,7 @@
"import umap\n",
"\n",
"from TTS.speaker_encoder.model import SpeakerEncoder\n",
"from TTS.utils.audio import AudioProcessor
\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.tts.utils.generic_utils import load_config\n",
"\n",
"from bokeh.io import output_notebook, show\n",

View File

@ -1,12 +1,12 @@
import os
from TTS.tts.configs import AlignTTSConfig
from TTS.tts.configs import BaseDatasetConfig
from TTS.trainer import init_training, Trainer, TrainingArgs
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import AlignTTSConfig, BaseDatasetConfig
output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/"))
dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
config = AlignTTSConfig(
batch_size=32,
eval_batch_size=16,
@ -23,8 +23,8 @@ config = AlignTTSConfig(
print_eval=True,
mixed_precision=False,
output_path=output_path,
datasets=[dataset_config]
datasets=[dataset_config],
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()

View File

@ -1,12 +1,12 @@
import os
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.configs import BaseDatasetConfig
from TTS.trainer import init_training, Trainer, TrainingArgs
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig
output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/"))
dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
config = GlowTTSConfig(
batch_size=32,
eval_batch_size=16,
@ -23,8 +23,8 @@ config = GlowTTSConfig(
print_eval=True,
mixed_precision=False,
output_path=output_path,
datasets=[dataset_config]
datasets=[dataset_config],
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()

View File

@ -24,6 +24,6 @@ config = HifiganConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()

View File

@ -1,8 +1,7 @@
import os
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.vocoder.configs import MultibandMelganConfig
from TTS.trainer import init_training, Trainer, TrainingArgs
output_path = os.path.dirname(os.path.abspath(__file__))
config = MultibandMelganConfig(
@ -25,6 +24,6 @@ config = MultibandMelganConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()

View File

@ -1,6 +1,5 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.vocoder.configs import UnivnetConfig
@ -25,6 +24,6 @@ config = UnivnetConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()

View File

@ -0,0 +1,52 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, VitsConfig
output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
)
config = VitsConfig(
audio=audio_config,
run_name="vits_ljspeech",
batch_size=48,
eval_batch_size=16,
batch_group_size=0,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
print_eval=True,
mixed_precision=True,
max_seq_len=5000,
output_path=output_path,
datasets=[dataset_config],
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
trainer.fit()

View File

@ -1,10 +1,8 @@
import os
from TTS.trainer import Trainer, init_training
from TTS.trainer import TrainingArgs
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.vocoder.configs import WavegradConfig
output_path = os.path.dirname(os.path.abspath(__file__))
config = WavegradConfig(
batch_size=32,
@ -24,6 +22,6 @@ config = WavegradConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()

View File

@ -1,9 +1,8 @@
import os
from TTS.trainer import Trainer, init_training, TrainingArgs
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.vocoder.configs import WavernnConfig
output_path = os.path.dirname(os.path.abspath(__file__))
config = WavernnConfig(
batch_size=64,
@ -25,6 +24,6 @@ config = WavernnConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
trainer.fit()

View File

@ -24,3 +24,4 @@ mecab-python3==1.0.3
unidic-lite==1.0.8
# gruut+supported langs
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
fsspec>=2021.04.0

View File

@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase):
r,
c.text_cleaner,
compute_linear_spec=True,
return_wav=True,
ap=self.ap,
meta_data=items,
characters=c.characters,
@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase):
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
wavs = data[11]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert isinstance(speaker_name[0], str)
assert linear_input.shape[0] == c.batch_size
assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.audio["num_mels"]
assert (
wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
# make sure that the computed mels and the waveform match and are correctly computed
mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
assert abs(mel_diff.sum()) < 1e-5
# check normalization ranges
if self.ap.symmetric_norm:
assert mel_input.max() <= self.ap.max_norm

View File

@ -27,6 +27,7 @@ config = AlignTTSConfig(
"Be a voice, not an echo.",
],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

View File

@ -29,6 +29,7 @@ config = Tacotron2Config(
"Be a voice, not an echo.",
],
d_vector_file="tests/data/ljspeech/speakers.json",
d_vector_dim=256,
max_decoder_steps=50,
)

View File

@ -25,8 +25,68 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
class TacotronTrainTest(unittest.TestCase):
"""Test vanilla Tacotron2 model."""
def test_train_step(self): # pylint: disable=no-self-use
config = config_global.copy()
config.use_speaker_embedding = False
config.num_speakers = 1
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[0] = 30
stop_targets = torch.zeros(8, 30, 1).float().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()) :, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron2(config).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for i in range(5):
outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore the pre-highway layer since it works conditionally
# if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class MultiSpeakerTacotronTrainTest(unittest.TestCase):
"""Test multi-speaker Tacotron2 with speaker embedding layer"""
@staticmethod
def test_train_step():
config = config_global.copy()
config.use_speaker_embedding = True
config.num_speakers = 5
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
@ -45,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):
criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
config.d_vector_dim = 55
model = Tacotron2(config).to(device)
model.train()
model_ref = copy.deepcopy(model)
@ -76,65 +137,18 @@ class TacotronTrainTest(unittest.TestCase):
count += 1
class MultiSpeakeTacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[0] = 30
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.rand(8, 55).to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()) :, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
config.d_vector_dim = 55
model = Tacotron2(config).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for i in range(5):
outputs = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_ids}
)
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore the pre-highway layer since it works conditionally
# if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class TacotronGSTTrainTest(unittest.TestCase):
"""Test multi-speaker Tacotron2 with Global Style Token and Speaker Embedding"""
# pylint: disable=no-self-use
def test_train_step(self):
# with random gst mel style
config = config_global.copy()
config.use_speaker_embedding = True
config.num_speakers = 10
config.use_gst = True
config.gst = GSTConfig()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
@ -247,9 +261,17 @@ class TacotronGSTTrainTest(unittest.TestCase):
class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
"""Test multi-speaker Tacotron2 with Global Style Tokens and d-vector inputs."""
@staticmethod
def test_train_step():
config = config_global.copy()
config.use_d_vector_file = True
config.use_gst = True
config.gst = GSTConfig()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]

View File

@ -0,0 +1,55 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = Tacotron2Config(
r=5,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
test_sentences=[
"Be a voice, not an echo.",
],
print_eval=True,
max_decoder_steps=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
f"--coqpit.output_path file://{output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
)
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -32,6 +32,61 @@ class TacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
config.use_speaker_embedding = False
config.num_speakers = 1
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128
mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[-1] = mel_spec.size(1)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()) :, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = L1LossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for _ in range(5):
outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
optimizer.zero_grad()
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore the pre-highway layer since it works conditionally
# if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class MultiSpeakeTacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
config.use_speaker_embedding = True
config.num_speakers = 5
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128
@ -50,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):
criterion = L1LossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
config.d_vector_dim = 55
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
@ -80,63 +136,14 @@ class TacotronTrainTest(unittest.TestCase):
count += 1
class MultiSpeakeTacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128
mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[-1] = mel_spec.size(1)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_embeddings = torch.rand(8, 55).to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()) :, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = L1LossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
config.d_vector_dim = 55
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for _ in range(5):
outputs = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
)
optimizer.zero_grad()
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore the pre-highway layer since it works conditionally
# if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class TacotronGSTTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
config.use_speaker_embedding = True
config.num_speakers = 10
config.use_gst = True
config.gst = GSTConfig()
# with random gst mel style
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
@ -244,6 +251,11 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
config.use_d_vector_file = True
config.use_gst = True
config.gst = GSTConfig()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128

View File

@ -0,0 +1,54 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
"Be a voice, not an echo.",
],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)