Mirror of https://github.com/coqui-ai/TTS.git, commit 01a2b0b5c0.
@@ -1,10 +1,15 @@
- ---
- name: 'Contribution Guideline '
- about: Refer to Contribution Guideline
- title: ''
- labels: ''
- assignees: ''
- ---
+ # Pull request guidelines

Welcome to the 🐸TTS project! We are excited to see your interest, and appreciate your support!

👐 Please check our [CONTRIBUTION GUIDELINE](https://github.com/coqui-ai/TTS#contribution-guidelines).

This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.

In order to make a good pull request, please see our [CONTRIBUTING.md](CONTRIBUTING.md) file.

Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS).

This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS):

- Protects you, Coqui, and the users of the code.
- Does not change your rights to use your contributions for any purpose.
- Does not change the license of the 🐸TTS project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute.
@@ -136,6 +136,7 @@ TTS/tts/layers/glow_tts/monotonic_align/core.c
temp_build/*
recipes/WIP/*
recipes/ljspeech/LJSpeech-1.1/*
+ recipes/ljspeech/tacotron2-DDC/LJSpeech-1.1/*
events.out*
old_configs/*
model_importers/*

@@ -152,4 +153,6 @@ output.wav
tts_output.wav
deps.json
speakers.json
internal/*
+ *_pitch.npy
+ *_phoneme.npy
Makefile

@@ -4,7 +4,7 @@
help:
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

- target_dirs := tests TTS notebooks
+ target_dirs := tests TTS notebooks recipes

test_all:	## run tests and don't stop on an error.
	nosetests --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id
@@ -73,10 +73,13 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950)

+ ### End-to-End Models
+ - VITS: [paper](https://arxiv.org/pdf/2106.06103)
+
### Attention Methods
- Guided Attention: [paper](https://arxiv.org/abs/1710.08969)
- Forward Backward Decoding: [paper](https://arxiv.org/abs/1907.09006)
- - Graves Attention: [paper](https://arxiv.org/abs/1907.09006)
+ - Graves Attention: [paper](https://arxiv.org/abs/1910.10288)
- Double Decoder Consistency: [blog](https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency/)
- Dynamic Convolutional Attention: [paper](https://arxiv.org/pdf/1910.10288.pdf)
TTS/.models.json

@@ -1,7 +1,7 @@
{
-   "tts_models":{
-     "en":{
-       "ek1":{
+   "tts_models": {
+     "en": {
+       "ek1": {
          "tacotron2": {
            "description": "EK1 en-rp tacotron2 by NMStoker",
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ek1--tacotron2.zip",

@@ -9,7 +9,7 @@
            "commit": "c802255"
          }
        },
-       "ljspeech":{
+       "ljspeech": {
          "tacotron2-DDC": {
            "description": "Tacotron2 with Double Decoder Consistency.",
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/tts_models--en--ljspeech--tacotron2-DDC.zip",

@@ -17,9 +17,18 @@
            "commit": "bae2ad0f",
            "author": "Eren Gölge @erogol",
            "license": "",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          },
-         "glow-tts":{
+         "tacotron2-DDC_ph": {
+           "description": "Tacotron2 with Double Decoder Consistency with phonemes.",
+           "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip",
+           "default_vocoder": "vocoder_models/en/ljspeech/univnet",
+           "commit": "3900448",
+           "author": "Eren Gölge @erogol",
+           "license": "",
+           "contact": "egolge@coqui.com"
+         },
+         "glow-tts": {
            "description": "",
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--en--ljspeech--glow-tts.zip",
            "stats_file": null,

@@ -27,7 +36,7 @@
            "commit": "",
            "author": "Eren Gölge @erogol",
            "license": "MPL",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          },
          "tacotron2-DCA": {
            "description": "",

@@ -36,19 +45,28 @@
            "commit": "",
            "author": "Eren Gölge @erogol",
            "license": "MPL",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          },
-         "speedy-speech-wn":{
+         "speedy-speech-wn": {
            "description": "Speedy Speech model with wavenet decoder.",
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ljspeech--speedy-speech-wn.zip",
            "default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan",
            "commit": "77b6145",
            "author": "Eren Gölge @erogol",
            "license": "MPL",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          },
+         "vits": {
+           "description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
+           "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--vits.zip",
+           "default_vocoder": null,
+           "commit": "3900448",
+           "author": "Eren Gölge @erogol",
+           "license": "",
+           "contact": "egolge@coqui.com"
+         }
        },
-       "vctk":{
+       "vctk": {
          "sc-glow-tts": {
            "description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.",
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip",

@@ -56,12 +74,19 @@
            "commit": "b531fa69",
            "author": "Edresson Casanova",
            "license": "",
-           "contact":""
+           "contact": ""
          },
+         "vits": {
+           "description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.",
+           "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--vctk--vits.zip",
+           "default_vocoder": null,
+           "commit": "3900448",
+           "author": "Eren @erogol",
+           "license": "",
+           "contact": "egolge@coqui.ai"
+         }
        },
-       "sam":{
+       "sam": {
          "tacotron-DDC": {
            "description": "Tacotron2 with Double Decoder Consistency trained with Accenture's Sam dataset.",
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/tts_models--en--sam--tacotron_DDC.zip",

@@ -69,46 +94,47 @@
            "commit": "bae2ad0f",
            "author": "Eren Gölge @erogol",
            "license": "",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          }
        }
      },
-     "es":{
-       "mai":{
-         "tacotron2-DDC":{
+     "es": {
+       "mai": {
+         "tacotron2-DDC": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--es--mai--tacotron2-DDC.zip",
            "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
            "commit": "",
            "author": "Eren Gölge @erogol",
            "license": "MPL",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          }
        }
      },
-     "fr":{
-       "mai":{
-         "tacotron2-DDC":{
+     "fr": {
+       "mai": {
+         "tacotron2-DDC": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--fr--mai--tacotron2-DDC.zip",
            "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
            "commit": "",
            "author": "Eren Gölge @erogol",
            "license": "MPL",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          }
        }
      },
-     "zh-CN":{
-       "baker":{
-         "tacotron2-DDC-GST":{
+     "zh-CN": {
+       "baker": {
+         "tacotron2-DDC-GST": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
            "commit": "unknown",
-           "author": "@kirianguiller"
+           "author": "@kirianguiller",
+           "default_vocoder": null
          }
        }
      },
-     "nl":{
-       "mai":{
-         "tacotron2-DDC":{
+     "nl": {
+       "mai": {
+         "tacotron2-DDC": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--nl--mai--tacotron2-DDC.zip",
            "author": "@r-dh",
            "default_vocoder": "vocoder_models/nl/mai/parallel-wavegan",

@@ -117,20 +143,9 @@
          }
        }
      },
-     "ru":{
-       "ruslan":{
-         "tacotron2-DDC":{
-           "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--ru--ruslan--tacotron2-DDC.zip",
-           "author": "@erogol",
-           "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
-           "license":"",
-           "contact": "egolge@coqui.com"
-         }
-       }
-     },
-     "de":{
-       "thorsten":{
-         "tacotron2-DCA":{
+     "de": {
+       "thorsten": {
+         "tacotron2-DCA": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/tts_models--de--thorsten--tacotron2-DCA.zip",
            "default_vocoder": "vocoder_models/de/thorsten/fullband-melgan",
            "author": "@thorstenMueller",

@@ -138,9 +153,9 @@
          }
        }
      },
-     "ja":{
-       "kokoro":{
-         "tacotron2-DDC":{
+     "ja": {
+       "kokoro": {
+         "tacotron2-DDC": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.15/tts_models--jp--kokoro--tacotron2-DDC.zip",
            "default_vocoder": "vocoder_models/universal/libri-tts/wavegrad",
            "description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.",

@@ -150,54 +165,62 @@
          }
        }
      },
-   "vocoder_models":{
-     "universal":{
-       "libri-tts":{
-         "wavegrad":{
+   "vocoder_models": {
+     "universal": {
+       "libri-tts": {
+         "wavegrad": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--wavegrad.zip",
            "commit": "ea976b0",
            "author": "Eren Gölge @erogol",
            "license": "MPL",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          },
-         "fullband-melgan":{
+         "fullband-melgan": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--fullband-melgan.zip",
            "commit": "4132240",
            "author": "Eren Gölge @erogol",
            "license": "MPL",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          }
        }
      },
      "en": {
-       "ek1":{
+       "ek1": {
          "wavegrad": {
            "description": "EK1 en-rp wavegrad by NMStoker",
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--en--ek1--wavegrad.zip",
            "commit": "c802255"
          }
        },
-       "ljspeech":{
-         "multiband-melgan":{
+       "ljspeech": {
+         "multiband-melgan": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--en--ljspeech--mulitband-melgan.zip",
            "commit": "ea976b0",
            "author": "Eren Gölge @erogol",
            "license": "MPL",
-           "contact":"egolge@coqui.com"
+           "contact": "egolge@coqui.com"
          },
-         "hifigan_v2":{
+         "hifigan_v2": {
            "description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.",
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--ljspeech-hifigan_v2.zip",
            "commit": "bae2ad0f",
            "author": "@erogol",
            "license": "",
            "contact": "egolge@coqui.ai"
          },
+         "univnet": {
+           "description": "UnivNet model trained on LJSpeech to complement the TacotronDDC_ph model.",
+           "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/vocoder_models--en--ljspeech--univnet.zip",
+           "commit": "3900448",
+           "author": "Eren @erogol",
+           "license": "",
+           "contact": "egolge@coqui.ai"
+         }
        },
-       "vctk":{
-         "hifigan_v2":{
+       "vctk": {
+         "hifigan_v2": {
            "description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts",
-           "github_rls_url":"https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--vctk--hifigan_v2.zip",
+           "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--vctk--hifigan_v2.zip",
            "commit": "2f07160",
            "author": "Edresson Casanova",
            "license": "",

@@ -205,9 +228,9 @@
          }
        },
        "sam": {
-         "hifigan_v2":{
+         "hifigan_v2": {
            "description": "Finetuned and intended to be used with tts_models/en/sam/tacotron_DDC",
-           "github_rls_url":"https://github.com/coqui-ai/TTS/releases/download/v0.0.13/vocoder_models--en--sam--hifigan_v2.zip",
+           "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/vocoder_models--en--sam--hifigan_v2.zip",
            "commit": "2f07160",
            "author": "Eren Gölge @erogol",
            "license": "",

@@ -215,28 +238,38 @@
          }
        }
      },
-     "nl":{
-       "mai":{
-         "parallel-wavegan":{
+     "nl": {
+       "mai": {
+         "parallel-wavegan": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--nl--mai--parallel-wavegan.zip",
            "author": "@r-dh",
            "commit": "unknown"
          }
        }
      },
-     "de":{
-       "thorsten":{
-         "wavegrad":{
+     "de": {
+       "thorsten": {
+         "wavegrad": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/vocoder_models--de--thorsten--wavegrad.zip",
            "author": "@thorstenMueller",
            "commit": "unknown"
          },
-         "fullband-melgan":{
+         "fullband-melgan": {
            "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.3/vocoder_models--de--thorsten--fullband-melgan.zip",
            "author": "@thorstenMueller",
            "commit": "unknown"
          }
        }
      },
+     "ja": {
+       "kokoro": {
+         "hifigan_v1": {
+           "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/vocoder_models--ja--kokoro--hifigan_v1.zip",
+           "description": "HifiGAN model trained for kokoro dataset by @kaiidams",
+           "author": "@kaiidams",
+           "commit": "3900448"
+         }
+       }
+     }
  }
}
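For readers skimming the manifest: it is a nested mapping of model type → language → dataset → model name → metadata, mirroring identifiers like `tts_models/en/ljspeech/vits` used on the TTS command line. A hand-rolled traversal sketch follows (TTS itself ships a `ModelManager` for this, so treat the loop below as illustration only, not the library API):

```python
# Illustration only: walk .models.json and print each model's download URL.
import json

with open(".models.json", "r", encoding="utf-8") as f:
    manifest = json.load(f)

for model_type, languages in manifest.items():
    for lang, datasets in languages.items():
        for dataset, models in datasets.items():
            for name, meta in models.items():
                model_id = f"{model_type}/{lang}/{dataset}/{name}"
                print(model_id, "->", meta.get("github_rls_url"))
```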
@@ -6,7 +6,7 @@ import numpy as np
import tensorflow as tf
import torch

- from TTS.utils.io import load_config
+ from TTS.utils.io import load_config, load_fsspec
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
    compare_torch_tf,
    convert_tf_name,

@@ -33,7 +33,7 @@ num_speakers = 0
# init torch model
model = setup_generator(c)
- checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
+ checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()
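The recurring pattern in this commit is swapping `torch.load(path)` for `load_fsspec(path)` so checkpoints can live on any fsspec-backed filesystem. A minimal sketch of what such a wrapper can look like, assuming the same call shape as the usages above (the actual helper in `TTS/utils/io.py` may differ in details):

```python
# Sketch of an fsspec-backed torch.load wrapper; assumptions noted above.
import fsspec
import torch


def load_fsspec(path: str, map_location=None, **kwargs):
    """Load a torch checkpoint from a local path or any fsspec URL (s3://, gs://, ...)."""
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)
```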
@@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.text.symbols import phonemes, symbols
- from TTS.utils.io import load_config
+ from TTS.utils.io import load_config, load_fsspec

sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = ""

@@ -32,7 +32,7 @@ num_speakers = 0
# init torch model
model = setup_model(c)
- checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
+ checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
@@ -16,6 +16,7 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
+ from TTS.utils.io import load_fsspec

use_cuda = torch.cuda.is_available()

@@ -239,7 +240,7 @@ def main(args):  # pylint: disable=redefined-outer-name
    model = setup_model(c)

    # restore model
-   checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+   checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])

    if use_cuda:
@@ -208,7 +208,7 @@ def main():
    if args.vocoder_name is not None and not args.vocoder_path:
        vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

-   # CASE3: set custome model paths
+   # CASE3: set custom model paths
    if args.model_path is not None:
        model_path = args.model_path
        config_path = args.config_path
@@ -17,6 +17,7 @@ from TTS.trainer import init_training
from TTS.tts.datasets import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
+ from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, check_update

@@ -115,12 +116,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
        "step_time": step_time,
        "avg_loader_time": avg_loader_time,
    }
-   tb_logger.tb_train_epoch_stats(global_step, train_stats)
+   dashboard_logger.train_epoch_stats(global_step, train_stats)
    figures = {
        # FIXME: not constant
        "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
    }
-   tb_logger.tb_train_figures(global_step, figures)
+   dashboard_logger.train_figures(global_step, figures)

    if global_step % c.print_step == 0:
        print(

@@ -169,7 +170,7 @@ def main(args):  # pylint: disable=redefined-outer-name
        raise Exception("The %s not is a loss supported" % c.loss)

    if args.restore_path:
-       checkpoint = torch.load(args.restore_path)
+       checkpoint = load_fsspec(args.restore_path)
        try:
            model.load_state_dict(checkpoint["model"])

@@ -207,7 +208,7 @@ def main(args):  # pylint: disable=redefined-outer-name

if __name__ == "__main__":
-   args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
+   args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv)

    try:
        main(args)
@@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training

def main():
    """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
-   args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-   trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False)
+   args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
+   trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False)
    trainer.fit()
@@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder

def main():
    try:
-       args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-       trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+       args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
+       trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
        trainer.fit()
    except KeyboardInterrupt:
        remove_experiment_folder(output_path)
@@ -3,6 +3,7 @@ import os
import re
from typing import Dict

+ import fsspec
import yaml
from coqpit import Coqpit

@@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module
def read_json_with_comments(json_path):
    """for backward compat."""
    # fallback to json
-   with open(json_path, "r", encoding="utf-8") as f:
+   with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    # handle comments
    input_str = re.sub(r"\\\n", "", input_str)

@@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
    config_dict = {}
    ext = os.path.splitext(config_path)[1]
    if ext in (".yml", ".yaml"):
-       with open(config_path, "r", encoding="utf-8") as f:
+       with fsspec.open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    elif ext == ".json":
        try:
-           with open(config_path, "r", encoding="utf-8") as f:
-               input_str = f.read()
-           data = json.loads(input_str)
+           with fsspec.open(config_path, "r", encoding="utf-8") as f:
+               data = json.load(f)
        except json.decoder.JSONDecodeError:
            # backwards compat.
            data = read_json_with_comments(config_path)
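With `fsspec.open()` handling the reads, the same loader accepts local paths and remote URLs transparently. A hypothetical usage sketch (the import path is assumed from this hunk's context, the bucket name is made up, and `s3://` additionally requires the `s3fs` backend):

```python
from TTS.config import load_config  # import path assumed

config = load_config("config.json")                     # local file, as before
config = load_config("s3://my-bucket/run/config.json")  # remote, via fsspec + s3fs
```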
@@ -36,6 +36,10 @@ class BaseAudioConfig(Coqpit):
    Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
do_trim_silence (bool):
    Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
+ do_amp_to_db_linear (bool, optional):
+     enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
+ do_amp_to_db_mel (bool, optional):
+     enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
trim_db (int):
    Silence threshold used for silence trimming. Defaults to 45.
power (float):

@@ -79,7 +83,7 @@ class BaseAudioConfig(Coqpit):
preemphasis: float = 0.0
ref_level_db: int = 20
do_sound_norm: bool = False
- log_func = "np.log10"
+ log_func: str = "np.log10"
# silence trimming
do_trim_silence: bool = True
trim_db: int = 45

@@ -91,6 +95,8 @@ class BaseAudioConfig(Coqpit):
mel_fmin: float = 0.0
mel_fmax: float = None
spec_gain: int = 20
+ do_amp_to_db_linear: bool = True
+ do_amp_to_db_mel: bool = True
# normalization params
signal_norm: bool = True
min_level_db: int = -100
@@ -182,51 +188,87 @@ class BaseTrainingConfig(Coqpit):
Args:
    model (str):
        Name of the model that is used in the training.

    run_name (str):
        Name of the experiment. This prefixes the output folder name.

    run_description (str):
        Short description of the experiment.

    epochs (int):
        Number of training epochs. Defaults to 10000.

    batch_size (int):
        Training batch size.

    eval_batch_size (int):
        Validation batch size.

    mixed_precision (bool):
        Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
        it may also cause numerical instability in some cases.

+   scheduler_after_epoch (bool):
+       If true, run the scheduler step after each epoch, else run it after each model step.

    run_eval (bool):
        Enable / Disable evaluation (validation) run. Defaults to True.

    test_delay_epochs (int):
        Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
        results, hence waiting for a couple of epochs might save some time.

    print_eval (bool):
        Enable / Disable console logging for evaluation steps. If disabled, it only shows the final values at
        the end of the evaluation. Defaults to ```False```.

    print_step (int):
        Number of steps required to print the next training log.

-   tb_plot_step (int):
+   log_dashboard (str): "tensorboard" or "wandb"
+       Set the experiment tracking tool.
+
+   plot_step (int):
        Number of steps required to log training on Tensorboard.

-   tb_model_param_stats (bool):
+   model_param_stats (bool):
        Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
        Defaults to ```False```.

+   project_name (str):
+       Name of the project. Defaults to config.model.
+
+   wandb_entity (str):
+       Name of W&B entity/team. Enables collaboration across a team or org.
+
+   log_model_step (int):
+       Number of steps required to log a checkpoint as a W&B artifact.

    save_step (int):
        Number of steps required to save the next checkpoint.

    checkpoint (bool):
        Enable / Disable checkpointing.

    keep_all_best (bool):
        Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
        to ```False```.

    keep_after (int):
        Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
        to 10000.

    num_loader_workers (int):
        Number of workers for the training time dataloader.

    num_eval_loader_workers (int):
        Number of workers for the evaluation time dataloader.

    output_path (str):
-       Path for the training output folder. The nonexistent part of the given path is created automatically.
-       All training outputs are saved there.
+       Path for the training output folder, either a local file path or other
+       URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
+       S3 (s3://) paths. The nonexistent part of the given path is created
+       automatically. All training artefacts are saved there.
"""

model: str = None

@@ -237,14 +279,19 @@
batch_size: int = None
eval_batch_size: int = None
mixed_precision: bool = False
+ scheduler_after_epoch: bool = False
# eval params
run_eval: bool = True
test_delay_epochs: int = 0
print_eval: bool = False
# logging
+ dashboard_logger: str = "tensorboard"
print_step: int = 25
- tb_plot_step: int = 100
- tb_model_param_stats: bool = False
+ plot_step: int = 100
+ model_param_stats: bool = False
+ project_name: str = None
+ log_model_step: int = None
+ wandb_entity: str = None
# checkpointing
save_step: int = 10000
checkpoint: bool = True
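The expanded `output_path` docstring is the user-facing side of the fsspec work: the training output folder may now be a remote URL. A hypothetical usage sketch (field names are taken from the hunk above, the import path is assumed, and the bucket name is made up; `s3://` needs `s3fs` installed):

```python
from TTS.config.shared_configs import BaseTrainingConfig  # import path assumed

config = BaseTrainingConfig(
    model="tacotron2",
    run_name="ljspeech-ddc",
    output_path="s3://my-bucket/tts-runs/",  # or a local folder path, as before
)
```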
@@ -103,8 +103,8 @@ synthesizer = Synthesizer(
    model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
)

- use_multi_speaker = synthesizer.tts_model.speaker_manager is not None and synthesizer.tts_model.num_speakers > 1
- speaker_manager = synthesizer.tts_model.speaker_manager if hasattr(synthesizer.tts_model, "speaker_manager") else None
+ use_multi_speaker = hasattr(synthesizer.tts_model, "speaker_manager") and synthesizer.tts_model.num_speakers > 1
+ speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
# TODO: set this from SpeakerManager
use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__)
@@ -2,6 +2,8 @@ import numpy as np
import torch
from torch import nn

+ from TTS.utils.io import load_fsspec


class LSTMWithProjection(nn.Module):
    def __init__(self, input_size, hidden_size, proj_size):

@@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module):

    # pylint: disable=unused-argument, redefined-builtin
    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-       state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+       state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if use_cuda:
            self.cuda()
@@ -2,6 +2,8 @@ import numpy as np
import torch
import torch.nn as nn

+ from TTS.utils.io import load_fsspec


class SELayer(nn.Module):
    def __init__(self, channel, reduction=8):

@@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module):
        return embeddings

    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-       state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+       state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if use_cuda:
            self.cuda()
@@ -6,11 +6,11 @@ import re
from multiprocessing import Manager

import numpy as np
import torch
from scipy import signal

from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
+ from TTS.utils.io import save_fsspec


class Storage(object):

@@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step):
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
-   torch.save(state, checkpoint_path)
+   save_fsspec(state, checkpoint_path)


def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):

@@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
    bestmodel_path = "best_model.pth.tar"
    bestmodel_path = os.path.join(out_path, bestmodel_path)
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-   torch.save(state, bestmodel_path)
+   save_fsspec(state, bestmodel_path)
    return best_loss
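Mirroring the load-side sketch earlier, the save path goes through fsspec as well. A plausible `save_fsspec`, assuming the call shape used above (the actual helper in `TTS/utils/io.py` may differ in details):

```python
# Sketch of an fsspec-backed torch.save wrapper; assumptions noted above.
import fsspec
import torch


def save_fsspec(state, path: str, **kwargs):
    """Save a torch checkpoint to a local path or any fsspec URL."""
    with fsspec.open(path, "wb") as f:
        torch.save(state, f, **kwargs)
```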
@@ -1,7 +1,7 @@
import datetime
import os

- import torch
+ from TTS.utils.io import save_fsspec


def save_checkpoint(model, optimizer, model_loss, out_path, current_step):

@@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
-   torch.save(state, checkpoint_path)
+   save_fsspec(state, checkpoint_path)


def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):

@@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
    bestmodel_path = "best_model.pth.tar"
    bestmodel_path = os.path.join(out_path, bestmodel_path)
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-   torch.save(state, bestmodel_path)
+   save_fsspec(state, bestmodel_path)
    return best_loss
TTS/trainer.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

- import glob
import importlib
+ import logging
+ import multiprocessing
import os
import platform
import re

@@ -12,7 +12,9 @@ import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
+ from urllib.parse import urlparse

+ import fsspec
import torch
from coqpit import Coqpit
from torch import nn

@@ -29,18 +31,20 @@ from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (
    KeepAverage,
    count_parameters,
-   create_experiment_folder,
+   get_experiment_folder_path,
    get_git_branch,
    remove_experiment_folder,
    set_init_dict,
    to_cuda,
)
- from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
- from TTS.utils.logging import ConsoleLogger, TensorboardLogger
+ from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
+ from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model

+ multiprocessing.set_start_method("fork")

if platform.system() != "Windows":
    # https://github.com/pytorch/pytorch/issues/973
    import resource

@@ -48,6 +52,7 @@ if platform.system() != "Windows":
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))


if is_apex_available():
    from apex import amp
@@ -87,7 +92,7 @@ class Trainer:
        config: Coqpit,
        output_path: str,
        c_logger: ConsoleLogger = None,
-       tb_logger: TensorboardLogger = None,
+       dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
        model: nn.Module = None,
        cudnn_benchmark: bool = False,
    ) -> None:

@@ -112,7 +117,7 @@ class Trainer:
        c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default
            console logger is used. Defaults to None.

-       tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
+       dashboard_logger (Union[TensorboardLogger, WandbLogger], optional): Dashboard logger. If not provided, the Tensorboard logger is used.
            Defaults to None.

        model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`

@@ -134,8 +139,8 @@ class Trainer:
        Running trainer on a config.

        >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
-       >>> args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-       >>> trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+       >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+       >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
        >>> trainer.fit()

        TODO:
|
|||
|
||||
# set and initialize Pytorch runtime
|
||||
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
|
||||
|
||||
if config is None:
|
||||
# parse config from console arguments
|
||||
config, output_path, _, c_logger, tb_logger = process_args(args)
|
||||
config, output_path, _, c_logger, dashboard_logger = process_args(args)
|
||||
|
||||
self.output_path = output_path
|
||||
self.args = args
|
||||
self.config = config
|
||||
|
||||
self.config.output_log_path = output_path
|
||||
# init loggers
|
||||
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
|
||||
if tb_logger is None:
|
||||
self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
|
||||
self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||
else:
|
||||
self.tb_logger = tb_logger
|
||||
self.dashboard_logger = dashboard_logger
|
||||
|
||||
if self.dashboard_logger is None:
|
||||
self.dashboard_logger = init_logger(config)
|
||||
|
||||
if not self.config.log_model_step:
|
||||
self.config.log_model_step = self.config.save_step
|
||||
|
||||
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
|
||||
self._setup_logger_config(log_file)
|
||||
|
||||
|
@ -173,7 +180,6 @@ class Trainer:
|
|||
self.best_loss = float("inf")
|
||||
self.train_loader = None
|
||||
self.eval_loader = None
|
||||
self.output_audio_path = os.path.join(output_path, "test_audios")
|
||||
|
||||
self.keep_avg_train = None
|
||||
self.keep_avg_eval = None
|
||||
|
@ -184,7 +190,7 @@ class Trainer:
|
|||
# init audio processor
|
||||
self.ap = AudioProcessor(**self.config.audio.to_dict())
|
||||
|
||||
# load dataset samples
|
||||
# load data samples
|
||||
# TODO: refactor this
|
||||
if "datasets" in self.config:
|
||||
# load data for `tts` models
|
||||
|
@ -205,6 +211,10 @@ class Trainer:
|
|||
else:
|
||||
self.model = self.get_model(self.config)
|
||||
|
||||
# init multispeaker settings of the model
|
||||
if hasattr(self.model, "init_multispeaker"):
|
||||
self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
|
||||
|
||||
# setup criterion
|
||||
self.criterion = self.get_criterion(self.model)
|
||||
|
||||
|
@ -274,9 +284,9 @@ class Trainer:
|
|||
"""
|
||||
# TODO: better model setup
|
||||
try:
|
||||
model = setup_tts_model(config)
|
||||
except ModuleNotFoundError:
|
||||
model = setup_vocoder_model(config)
|
||||
except ModuleNotFoundError:
|
||||
model = setup_tts_model(config)
|
||||
return model
|
||||
|
||||
def restore_model(
|
||||
|
@ -309,7 +319,7 @@ class Trainer:
|
|||
return obj
|
||||
|
||||
print(" > Restoring from %s ..." % os.path.basename(restore_path))
|
||||
checkpoint = torch.load(restore_path)
|
||||
checkpoint = load_fsspec(restore_path)
|
||||
try:
|
||||
print(" > Restoring Model...")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
|
@ -417,7 +427,7 @@ class Trainer:
|
|||
scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access
|
||||
config: Coqpit,
|
||||
optimizer_idx: int = None,
|
||||
) -> Tuple[Dict, Dict, int, torch.Tensor]:
|
||||
) -> Tuple[Dict, Dict, int]:
|
||||
"""Perform a forward - backward pass and run the optimizer.
|
||||
|
||||
Args:
|
||||
|
@ -426,7 +436,7 @@ class Trainer:
|
|||
optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
|
||||
scaler (AMPScaler): AMP scaler.
|
||||
criterion (nn.Module): Model's criterion.
|
||||
scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
|
||||
config (Coqpit): Model config.
|
||||
optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
|
||||
|
||||
|
@ -436,6 +446,7 @@ class Trainer:
|
|||
Returns:
|
||||
Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
|
||||
"""
|
||||
|
||||
step_start_time = time.time()
|
||||
# zero-out optimizer
|
||||
optimizer.zero_grad()
|
||||
|
@ -448,11 +459,11 @@ class Trainer:
|
|||
# skip the rest
|
||||
if outputs is None:
|
||||
step_time = time.time() - step_start_time
|
||||
return None, {}, step_time, 0
|
||||
return None, {}, step_time
|
||||
|
||||
# check nan loss
|
||||
if torch.isnan(loss_dict["loss"]).any():
|
||||
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
|
||||
raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
|
||||
|
||||
# set gradient clipping threshold
|
||||
if "grad_clip" in config and config.grad_clip is not None:
|
||||
|
@ -463,7 +474,6 @@ class Trainer:
|
|||
else:
|
||||
grad_clip = 0.0 # meaning no gradient clipping
|
||||
|
||||
# TODO: compute grad norm
|
||||
if grad_clip <= 0:
|
||||
grad_norm = 0
|
||||
|
||||
|
@@ -474,15 +484,17 @@ class Trainer:
                with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
-                   amp.master_params(optimizer),
-                   grad_clip,
+                   amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
                )
            else:
                # model optimizer step in mixed precision mode
                scaler.scale(loss_dict["loss"]).backward()
-               scaler.unscale_(optimizer)
                if grad_clip > 0:
-                   grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+                   scaler.unscale_(optimizer)
+                   grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+                   # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
+                   if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                       grad_norm = 0
                scale_prev = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()

@@ -491,13 +503,13 @@ class Trainer:
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            if grad_clip > 0:
-               grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+               grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
            optimizer.step()

        step_time = time.time() - step_start_time

        # setup lr
-       if scheduler is not None and update_lr_scheduler:
+       if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
            scheduler.step()

        # detach losses

@@ -505,7 +517,9 @@ class Trainer:
        if optimizer_idx is not None:
            loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
+           loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
-       return outputs, loss_dict, step_time, grad_norm
+       else:
+           loss_dict["grad_norm"] = grad_norm
+       return outputs, loss_dict, step_time

    @staticmethod
    def _detach_loss_dict(loss_dict: Dict) -> Dict:
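The reordering in the AMP branch above matches PyTorch's documented recipe: gradients must be unscaled before clipping, or the clip threshold applies to scaled values and is meaningless. A minimal, self-contained sketch of that pattern (toy model and values; needs a CUDA device to run):

```python
# Canonical AMP step: backward on scaled loss, unscale, clip, step, update.
import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
grad_clip = 1.0

x = torch.randn(8, 10, device="cuda")
with torch.cuda.amp.autocast():
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()       # backward on the scaled loss
if grad_clip > 0:
    scaler.unscale_(optimizer)      # unscale first, so the threshold is in real units
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)              # skips the update if grads contain inf/NaN
scaler.update()
optimizer.zero_grad()
```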
@@ -544,11 +558,10 @@ class Trainer:

        # containers to hold model outputs and losses for each optimizer.
        outputs_per_optimizer = None
-       log_dict = {}
        loss_dict = {}
        if not isinstance(self.optimizer, list):
            # training with a single optimizer
-           outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+           outputs, loss_dict_new, step_time = self._optimize(
                batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
            )
            loss_dict.update(loss_dict_new)

@@ -560,25 +573,36 @@ class Trainer:
                criterion = self.criterion
                scaler = self.scaler[idx] if self.use_amp_scaler else None
                scheduler = self.scheduler[idx]
-               outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+               outputs, loss_dict_new, step_time = self._optimize(
                    batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
                )
                # skip the rest if the model returns None
                total_step_time += step_time
                outputs_per_optimizer[idx] = outputs
                # merge loss_dicts from each optimizer
                # rename duplicates with the optimizer idx
                # if None, model skipped this optimizer
                if loss_dict_new is not None:
-                   loss_dict.update(loss_dict_new)
+                   for k, v in loss_dict_new.items():
+                       if k in loss_dict:
+                           loss_dict[f"{k}-{idx}"] = v
+                       else:
+                           loss_dict[k] = v
            step_time = total_step_time
            outputs = outputs_per_optimizer

-       # update avg stats
+       # update avg runtime stats
        keep_avg_update = dict()
-       for key, value in log_dict.items():
-           keep_avg_update["avg_" + key] = value
        keep_avg_update["avg_loader_time"] = loader_time
        keep_avg_update["avg_step_time"] = step_time
        self.keep_avg_train.update_values(keep_avg_update)

+       # update avg loss stats
+       update_eval_values = dict()
+       for key, value in loss_dict.items():
+           update_eval_values["avg_" + key] = value
+       self.keep_avg_train.update_values(update_eval_values)

        # print training progress
        if self.total_steps_done % self.config.print_step == 0:
            # log learning rates
@@ -590,33 +614,27 @@ class Trainer:
            else:
                current_lr = self.optimizer.param_groups[0]["lr"]
            lrs = {"current_lr": current_lr}
-           log_dict.update(lrs)
-           if grad_norm > 0:
-               log_dict.update({"grad_norm": grad_norm})

            # log run-time stats
-           log_dict.update(
+           loss_dict.update(
                {
                    "step_time": round(step_time, 4),
                    "loader_time": round(loader_time, 4),
                }
            )
            self.c_logger.print_train_step(
-               batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
+               batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
            )

        if self.args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load and don't log every step
-           if self.total_steps_done % self.config.tb_plot_step == 0:
-               iter_stats = log_dict
-               iter_stats.update(loss_dict)
-               self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
+           if self.total_steps_done % self.config.plot_step == 0:
+               self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict)
            if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                if self.config.checkpoint:
                    # checkpoint the model
-                   model_loss = (
-                       loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
-                   )
+                   target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
                    save_checkpoint(
                        self.config,
                        self.model,

@@ -625,8 +643,14 @@ class Trainer:
                        self.total_steps_done,
                        self.epochs_done,
                        self.output_path,
-                       model_loss=model_loss,
+                       model_loss=target_avg_loss,
                    )

+               if self.total_steps_done % self.config.log_model_step == 0:
+                   # log checkpoint as artifact
+                   aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
+                   self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)

            # training visualizations
            figures, audios = None, None
            if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):

@@ -634,11 +658,13 @@ class Trainer:
            elif hasattr(self.model, "train_log"):
                figures, audios = self.model.train_log(self.ap, batch, outputs)
            if figures is not None:
-               self.tb_logger.tb_train_figures(self.total_steps_done, figures)
+               self.dashboard_logger.train_figures(self.total_steps_done, figures)
            if audios is not None:
-               self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
+               self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)

        self.total_steps_done += 1
        self.callbacks.on_train_step_end()
+       self.dashboard_logger.flush()
        return outputs, loss_dict

    def train_epoch(self) -> None:
@@ -663,9 +689,17 @@ class Trainer:
        if self.args.rank == 0:
            epoch_stats = {"epoch_time": epoch_time}
            epoch_stats.update(self.keep_avg_train.avg_values)
-           self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
-           if self.config.tb_model_param_stats:
-               self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+           self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats)
+           if self.config.model_param_stats:
+               self.logger.model_weights(self.model, self.total_steps_done)
+       # scheduler step after the epoch
+       if self.scheduler is not None and self.config.scheduler_after_epoch:
+           if isinstance(self.scheduler, list):
+               for scheduler in self.scheduler:
+                   if scheduler is not None:
+                       scheduler.step()
+           else:
+               self.scheduler.step()

    @staticmethod
    def _model_eval_step(
@@ -701,19 +735,22 @@ class Trainer:
            Tuple[Dict, Dict]: Model outputs and losses.
        """
        with torch.no_grad():
-           outputs_per_optimizer = None
+           outputs = []
            loss_dict = {}
            if not isinstance(self.optimizer, list):
                outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
            else:
-               outputs_per_optimizer = [None] * len(self.optimizer)
+               outputs = [None] * len(self.optimizer)
                for idx, _ in enumerate(self.optimizer):
                    criterion = self.criterion
-                   outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
-                   outputs_per_optimizer[idx] = outputs
+                   outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
+                   outputs[idx] = outputs_

                    if loss_dict_new is not None:
                        loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
                        loss_dict.update(loss_dict_new)
-               outputs = outputs_per_optimizer

            loss_dict = self._detach_loss_dict(loss_dict)

        # update avg stats
        update_eval_values = dict()
@@ -755,28 +792,35 @@ class Trainer:
        elif hasattr(self.model, "eval_log"):
            figures, audios = self.model.eval_log(self.ap, batch, outputs)
        if figures is not None:
-           self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
+           self.dashboard_logger.eval_figures(self.total_steps_done, figures)
        if audios is not None:
-           self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
-       self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
+           self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+       self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)

    def test_run(self) -> None:
        """Run test and log the results. Test run must be defined by the model.
        Model must return figures and audios to be logged by the Tensorboard."""
        if hasattr(self.model, "test_run"):
+           if self.eval_loader is None:
+               self.eval_loader = self.get_eval_dataloader(
+                   self.ap,
+                   self.data_eval,
+                   verbose=True,
+               )

            if hasattr(self.eval_loader.dataset, "load_test_samples"):
                samples = self.eval_loader.dataset.load_test_samples(1)
                figures, audios = self.model.test_run(self.ap, samples, None)
            else:
                figures, audios = self.model.test_run(self.ap)
-           self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
-           self.tb_logger.tb_test_figures(self.total_steps_done, figures)
+           self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
+           self.dashboard_logger.test_figures(self.total_steps_done, figures)

    def _fit(self) -> None:
        """🏃 train -> evaluate -> test for the number of epochs."""
        if self.restore_step != 0 or self.args.best_path:
            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-           self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
+           self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
            print(f" > Starting with loaded last best loss {self.best_loss}.")

        self.total_steps_done = self.restore_step

@@ -802,10 +846,13 @@ class Trainer:
        """Where the ✨️magic✨️ happens..."""
        try:
            self._fit()
+           self.dashboard_logger.finish()
        except KeyboardInterrupt:
            self.callbacks.on_keyboard_interrupt()
            # if the output folder is empty remove the run.
            remove_experiment_folder(self.output_path)
+           # finish the wandb run and sync data
+           self.dashboard_logger.finish()
            # stop without error signal
            try:
                sys.exit(0)
@@ -816,10 +863,33 @@ class Trainer:
            traceback.print_exc()
            sys.exit(1)

+   def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
+       """Pick the target loss to compare models"""
+       target_avg_loss = None
+
+       # return if target loss defined in the model config
+       if "target_loss" in self.config and self.config.target_loss:
+           return keep_avg_target[f"avg_{self.config.target_loss}"]
+
+       # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
+       if isinstance(self.optimizer, list):
+           target_avg_loss = 0
+           for idx in range(len(self.optimizer)):
+               target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
+           target_avg_loss /= len(self.optimizer)
+       else:
+           target_avg_loss = keep_avg_target["avg_loss"]
+       return target_avg_loss

    def save_best_model(self) -> None:
        """Save the best model. It only saves if the current target loss is smaller than the previous one."""

+       # set the target loss to choose the best model
+       target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)

        # save the model and update the best_loss
        self.best_loss = save_best_model(
-           self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+           target_loss_dict,
            self.best_loss,
            self.config,
            self.model,

@@ -834,9 +904,16 @@ class Trainer:

    @staticmethod
    def _setup_logger_config(log_file: str) -> None:
-       logging.basicConfig(
-           level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
-       )
+       handlers = [logging.StreamHandler()]
+
+       # Only add a log file if the output location is local due to poor
+       # support for writing logs to file-like objects.
+       parsed_url = urlparse(log_file)
+       if not parsed_url.scheme or parsed_url.scheme == "file":
+           schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
+           handlers.append(logging.FileHandler(schemeless_path))
+
+       logging.basicConfig(level=logging.INFO, format="", handlers=handlers)

    @staticmethod
    def _is_apex_available() -> bool:
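A quick illustration of the scheme check used in `_setup_logger_config` above; the values shown are what `urllib.parse` returns, nothing TTS-specific:

```python
from urllib.parse import urlparse

print(urlparse("/tmp/run/trainer_0_log.txt").scheme)        # "" -> local, gets a FileHandler
print(urlparse("gs://bucket/run/trainer_0_log.txt").scheme)  # "gs" -> remote, stream-only logging
```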
@ -920,28 +997,33 @@ class Trainer:
|
|||
return criterion
|
||||
|
||||
|
||||
def init_arguments():
|
||||
def getarguments():
|
||||
train_config = TrainingArgs()
|
||||
parser = train_config.init_argparse(arg_prefix="")
|
||||
return parser
|
||||
|
||||
|
||||
def get_last_checkpoint(path):
|
||||
def get_last_checkpoint(path: str) -> Tuple[str, str]:
|
||||
"""Get latest checkpoint or/and best model in path.
|
||||
|
||||
It is based on globbing for `*.pth.tar` and the RegEx
|
||||
`(checkpoint|best_model)_([0-9]+)`.
|
||||
|
||||
Args:
|
||||
path (list): Path to files to be compared.
|
||||
path: Path to files to be compared.
|
||||
|
||||
Raises:
|
||||
ValueError: If no checkpoint or best_model files are found.
|
||||
|
||||
Returns:
|
||||
last_checkpoint (str): Last checkpoint filename.
|
||||
Path to the last checkpoint
|
||||
Path to best checkpoint
|
||||
"""
|
||||
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
|
||||
fs = fsspec.get_mapper(path).fs
|
||||
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
|
||||
scheme = urlparse(path).scheme
|
||||
if scheme: # scheme is not preserved in fs.glob, add it back
|
||||
file_names = [scheme + "://" + file_name for file_name in file_names]
|
||||
last_models = {}
|
||||
last_model_nums = {}
|
||||
for key in ["checkpoint", "best_model"]:
|
||||
|
@ -963,7 +1045,7 @@ def get_last_checkpoint(path):
|
|||
key_file_names = [fn for fn in file_names if key in fn]
|
||||
if last_model is None and len(key_file_names) > 0:
|
||||
last_model = max(key_file_names, key=os.path.getctime)
|
||||
last_model_num = torch.load(last_model)["step"]
|
||||
last_model_num = load_fsspec(last_model)["step"]
|
||||
|
||||
if last_model is not None:
|
||||
last_models[key] = last_model
|
||||
|
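For reference, the `(checkpoint|best_model)_([0-9]+)` pattern pulls the training step out of the checkpoint filenames; a minimal check (the file names here are made up for illustration):

    import re

    for name in ["checkpoint_25000.pth.tar", "best_model_18760.pth.tar"]:
        match = re.search(r"(checkpoint|best_model)_([0-9]+)", name)
        print(match.group(1), int(match.group(2)))  # checkpoint 25000 / best_model 18760
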
@@ -997,8 +1079,8 @@ def process_args(args, config=None):
        audio_path (str): Path to save generated test audios.
        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
            logging to the console.
        tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
            the TensorBoard logging.

        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard logging.

    TODO:
        - Interactive config definition.

@@ -1012,6 +1094,7 @@ def process_args(args, config=None):
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model

    # init config if not already defined
    if config is None:
        if args.config_path:

@@ -1030,12 +1113,12 @@ def process_args(args, config=None):
            print(" > Mixed precision mode is ON")
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = create_experiment_folder(config.output_path, config.run_name)
        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
    audio_path = os.path.join(experiment_path, "test_audios")
    config.output_log_path = experiment_path
    # setup rank 0 process in distributed training
    tb_logger = None
    dashboard_logger = None
    if args.rank == 0:
        os.makedirs(audio_path, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path

@@ -1043,17 +1126,20 @@ def process_args(args, config=None):
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if config.has("characters_config"):
        if config.has("characters") and config.characters is None:
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
        os.chmod(audio_path, 0o775)
        os.chmod(experiment_path, 0o775)
        tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
        # write model desc to tensorboard
        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)

        dashboard_logger = init_logger(config)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, tb_logger
    return config, experiment_path, audio_path, c_logger, dashboard_logger


def init_arguments():
    train_config = TrainingArgs()
    parser = train_config.init_argparse(arg_prefix="")
    return parser


def init_training(argv: Union[List, Coqpit], config: Coqpit = None):

@@ -1063,5 +1149,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
    else:
        parser = init_arguments()
    args = parser.parse_known_args()
    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config)
    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
    config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
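`init_training` bundles argument parsing and `process_args` into one entry point. A minimal usage sketch; the unpacking order matches the return statement above, and `sys.argv` is the trainer's normal CLI path:

    import sys

    # returns the parsed args plus everything process_args prepared
    args, config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv)
    print(OUT_PATH)    # experiment folder for checkpoints and logs
    print(AUDIO_PATH)  # "<experiment>/test_audios"
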
@@ -13,12 +13,16 @@ class GSTConfig(Coqpit):
    Args:
        gst_style_input_wav (str):
            Path to the wav file used to define the style of the output speech at inference. Defaults to None.

        gst_style_input_weights (dict):
            Defines the weights for each style token used at inference. Defaults to None.

        gst_embedding_dim (int):
            Defines the size of the GST embedding vector dimensions. Defaults to 256.

        gst_num_heads (int):
            Number of attention heads used by the multi-head attention. Defaults to 4.

        gst_num_style_tokens (int):
            Number of style token vectors. Defaults to 10.
    """
@@ -51,17 +55,23 @@ class CharactersConfig(Coqpit):
    Args:
        pad (str):
            characters in place of empty padding. Defaults to None.

        eos (str):
            characters showing the end of a sentence. Defaults to None.

        bos (str):
            characters showing the beginning of a sentence. Defaults to None.

        characters (str):
            character set used by the model. Characters not in this list are ignored when converting input text to
            a list of sequence IDs. Defaults to None.

        punctuations (str):
            characters considered as punctuation when parsing the input sentence. Defaults to None.

        phonemes (str):
            characters considered as phonemes when parsing. Defaults to None.

        unique (bool):
            remove any duplicate characters in the character lists. It is a band-aid for compatibility with the old
            models trained with character lists that contain duplicates.
@@ -95,54 +105,78 @@ class BaseTTSConfig(BaseTrainingConfig):
    Args:
        audio (BaseAudioConfig):
            Audio processor config object instance.

        use_phonemes (bool):
            enable / disable phoneme use.

        use_espeak_phonemes (bool):
            enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`).

        compute_input_seq_cache (bool):
            enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of
            the training, it allows faster data loading and precise limits with `max_seq_len` and
            `min_seq_len`.

        text_cleaner (str):
            Name of the text cleaner used for cleaning and formatting transcripts.

        enable_eos_bos_chars (bool):
            enable / disable the use of eos and bos characters.

        test_senteces_file (str):
            Path to a txt file that has sentences used at test time. The file must have a sentence per line.

        phoneme_cache_path (str):
            Path to the output folder caching the computed phonemes for each sample.

        characters (CharactersConfig):
            Instance of a CharactersConfig class.

        batch_group_size (int):
            Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
            length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
            prevent using the same batches for each epoch.

        loss_masking (bool):
            enable / disable masking loss values against padded segments of samples in a batch.

        min_seq_len (int):
            Minimum input sequence length to be used at training.

        max_seq_len (int):
            Maximum input sequence length to be used at training. Larger values result in more VRAM usage.

        compute_f0 (int):
            (Not in use yet).

        compute_linear_spec (bool):
            If True, the data loader computes and returns linear spectrograms alongside the other data.

        use_noise_augment (bool):
            Augment the input audio with random noise.

        add_blank (bool):
            Add a blank character between every two characters. It improves performance for some models at the expense
            of slower run-time due to the longer input sequence.

        datasets (List[BaseDatasetConfig]):
            List of datasets used for training. If multiple datasets are provided, they are merged and used together
            for training.

        optimizer (str):
            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
            Defaults to ``.

        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`

        lr_scheduler (str):
            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
            `TTS.utils.training`. Defaults to ``.

        lr_scheduler_params (dict):
            Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.

        test_sentences (List[str]):
            List of sentences to be used at testing. Defaults to `[]`.
    """

@@ -166,6 +200,7 @@ class BaseTTSConfig(BaseTrainingConfig):
    min_seq_len: int = 1
    max_seq_len: int = float("inf")
    compute_f0: bool = False
    compute_linear_spec: bool = False
    use_noise_augment: bool = False
    add_blank: bool = False
    # dataset
@@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs


@dataclass
class VitsConfig(BaseTTSConfig):
    """Defines parameters for VITS End2End TTS model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (VitsArgs):
            Model architecture arguments. Defaults to `VitsArgs()`.

        grad_clip (List):
            Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`.

        lr_gen (float):
            Initial learning rate for the generator. Defaults to 0.0002.

        lr_disc (float):
            Initial learning rate for the discriminator. Defaults to 0.0002.

        lr_scheduler_gen (str):
            Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_gen_params (dict):
            Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch": -1}`.

        lr_scheduler_disc (str):
            Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_disc_params (dict):
            Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch": -1}`.

        scheduler_after_epoch (bool):
            If true, step the schedulers after each epoch, else after each step. Defaults to `True`.

        optimizer (str):
            Name of the optimizer to use with both the generator and the discriminator networks. One of the
            `torch.optim.*`. Defaults to `AdamW`.

        kl_loss_alpha (float):
            Loss weight for the KL loss. Defaults to 1.0.

        disc_loss_alpha (float):
            Loss weight for the discriminator loss. Defaults to 1.0.

        gen_loss_alpha (float):
            Loss weight for the generator loss. Defaults to 1.0.

        feat_loss_alpha (float):
            Loss weight for the feature matching loss. Defaults to 1.0.

        mel_loss_alpha (float):
            Loss weight for the mel loss. Defaults to 45.0.

        return_wav (bool):
            If true, the data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

        compute_linear_spec (bool):
            If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

        min_seq_len (int):
            Minimum text length to be considered for training. Defaults to `13`.

        max_seq_len (int):
            Maximum text length to be considered for training. Defaults to `500`.

        r (int):
            Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

        add_blank (bool):
            If true, a blank token is added in between every character. Defaults to `True`.

        test_sentences (List[str]):
            List of sentences to be used for testing.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

    Example:

        >>> from TTS.tts.configs import VitsConfig
        >>> config = VitsConfig()
    """

    model: str = "vits"
    # model specific params
    model_args: VitsArgs = field(default_factory=VitsArgs)

    # optimizer
    grad_clip: List[float] = field(default_factory=lambda: [5, 5])
    lr_gen: float = 0.0002
    lr_disc: float = 0.0002
    lr_scheduler_gen: str = "ExponentialLR"
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # loss params
    kl_loss_alpha: float = 1.0
    disc_loss_alpha: float = 1.0
    gen_loss_alpha: float = 1.0
    feat_loss_alpha: float = 1.0
    mel_loss_alpha: float = 45.0

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # overrides
    min_seq_len: int = 13
    max_seq_len: int = 500
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # testing
    test_sentences: List[str] = field(
        default_factory=lambda: [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist.",
            "Prior to November 22, 1963.",
        ]
    )
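Since `VitsConfig` is a Coqpit dataclass, overriding any documented field is just a keyword argument. A minimal sketch; `run_name` is assumed here to come from the inherited base training config:

    from TTS.tts.configs import VitsConfig

    config = VitsConfig(
        run_name="vits_ljspeech",  # inherited training field (assumption for this sketch)
        max_seq_len=300,           # tighten the length filter
    )
    print(config.optimizer, config.lr_gen)  # AdamW 0.0002
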
@@ -23,7 +23,9 @@ class TTSDataset(Dataset):
        ap: AudioProcessor,
        meta_data: List[List],
        characters: Dict = None,
        custom_symbols: List = None,
        add_blank: bool = False,
        return_wav: bool = False,
        batch_group_size: int = 0,
        min_seq_len: int = 0,
        max_seq_len: int = float("inf"),

@@ -54,9 +56,14 @@ class TTSDataset(Dataset):

            characters (dict): `dict` of custom text characters used for converting texts to sequences.

            custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using their own
                set of symbols need to pass them here. Defaults to `None`.

            add_blank (bool): Add a special `blank` character after every other character. It helps some
                models achieve better results. Defaults to false.

            return_wav (bool): Return the waveform of the sample. Defaults to False.

            batch_group_size (int): Range of batch randomization after sorting
                sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
                batch. Set 0 to disable. Defaults to 0.

@@ -95,10 +102,12 @@ class TTSDataset(Dataset):
        self.sample_rate = ap.sample_rate
        self.cleaners = text_cleaner
        self.compute_linear_spec = compute_linear_spec
        self.return_wav = return_wav
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len
        self.ap = ap
        self.characters = characters
        self.custom_symbols = custom_symbols
        self.add_blank = add_blank
        self.use_phonemes = use_phonemes
        self.phoneme_cache_path = phoneme_cache_path

@@ -109,6 +118,7 @@ class TTSDataset(Dataset):
        self.use_noise_augment = use_noise_augment
        self.verbose = verbose
        self.input_seq_computed = False
        self.rescue_item_idx = 1
        if use_phonemes and not os.path.isdir(phoneme_cache_path):
            os.makedirs(phoneme_cache_path, exist_ok=True)
        if self.verbose:

@@ -128,13 +138,21 @@ class TTSDataset(Dataset):
        return data

    @staticmethod
    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
    def _generate_and_cache_phoneme_sequence(
        text, cache_path, cleaners, language, custom_symbols, characters, add_blank
    ):
        """Generate a phoneme sequence from text.

        Since the usage is for subsequent caching, we never add bos and
        eos chars here. Instead we add those dynamically later, based on the
        config option."""
        phonemes = phoneme_to_sequence(
            text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
            text,
            [cleaners],
            language=language,
            enable_eos_bos=False,
            custom_symbols=custom_symbols,
            tp=characters,
            add_blank=add_blank,
        )
        phonemes = np.asarray(phonemes, dtype=np.int32)
        np.save(cache_path, phonemes)

@@ -142,7 +160,7 @@ class TTSDataset(Dataset):

    @staticmethod
    def _load_or_generate_phoneme_sequence(
        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
    ):
        file_name = os.path.splitext(os.path.basename(wav_file))[0]

@@ -153,12 +171,12 @@ class TTSDataset(Dataset):
            phonemes = np.load(cache_path)
        except FileNotFoundError:
            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
                text, cache_path, cleaners, language, characters, add_blank
                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
            )
        except (ValueError, IOError):
            print(" [!] failed loading phonemes for {}. Recomputing.".format(wav_file))
            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
                text, cache_path, cleaners, language, characters, add_blank
                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
            )
        if enable_eos_bos:
            phonemes = pad_with_eos_bos(phonemes, tp=characters)

@@ -173,6 +191,7 @@ class TTSDataset(Dataset):
        else:
            text, wav_file, speaker_name = item
            attn = None
        raw_text = text

        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)

@@ -189,13 +208,19 @@ class TTSDataset(Dataset):
                self.enable_eos_bos,
                self.cleaners,
                self.phoneme_language,
                self.custom_symbols,
                self.characters,
                self.add_blank,
            )

        else:
            text = np.asarray(
                text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
                text_to_sequence(
                    text,
                    [self.cleaners],
                    custom_symbols=self.custom_symbols,
                    tp=self.characters,
                    add_blank=self.add_blank,
                ),
                dtype=np.int32,
            )

@@ -209,9 +234,10 @@ class TTSDataset(Dataset):
            # return a different sample if the phonemized
            # text is longer than the threshold
            # TODO: find a better fix
            return self.load_data(100)
            return self.load_data(self.rescue_item_idx)

        sample = {
            "raw_text": raw_text,
            "text": text,
            "wav": wav,
            "attn": attn,

@@ -238,7 +264,13 @@ class TTSDataset(Dataset):
            for idx, item in enumerate(tqdm.tqdm(self.items)):
                text, *_ = item
                sequence = np.asarray(
                    text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
                    text_to_sequence(
                        text,
                        [self.cleaners],
                        custom_symbols=self.custom_symbols,
                        tp=self.characters,
                        add_blank=self.add_blank,
                    ),
                    dtype=np.int32,
                )
                self.items[idx][0] = sequence

@@ -249,6 +281,7 @@ class TTSDataset(Dataset):
                    self.enable_eos_bos,
                    self.cleaners,
                    self.phoneme_language,
                    self.custom_symbols,
                    self.characters,
                    self.add_blank,
                ]

@@ -329,6 +362,7 @@ class TTSDataset(Dataset):
            wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
            item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
            text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
            raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]

            speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
            # get pre-computed d-vectors

@@ -347,6 +381,14 @@ class TTSDataset(Dataset):

            mel_lengths = [m.shape[1] for m in mel]

            # lengths adjusted by the reduction factor
            mel_lengths_adjusted = [
                m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
                if m.shape[1] % self.outputs_per_step
                else m.shape[1]
                for m in mel
            ]
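The list comprehension rounds each mel length up to the next multiple of `outputs_per_step` (the reduction factor `r`). A worked example with made-up lengths:

    outputs_per_step = 2        # r = 2 for this sketch
    mel_lengths = [100, 101, 103]
    mel_lengths_adjusted = [
        m + (outputs_per_step - (m % outputs_per_step)) if m % outputs_per_step else m
        for m in mel_lengths
    ]
    print(mel_lengths_adjusted)  # [100, 102, 104]
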
            # compute 'stop token' targets
            stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]

@@ -385,6 +427,20 @@ class TTSDataset(Dataset):
            else:
                linear = None

            # format waveforms
            wav_padded = None
            if self.return_wav:
                wav_lengths = [w.shape[0] for w in wav]
                max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
                wav_lengths = torch.LongTensor(wav_lengths)
                wav_padded = torch.zeros(len(batch), 1, max_wav_len)
                for i, w in enumerate(wav):
                    mel_length = mel_lengths_adjusted[i]
                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                    w = w[: mel_length * self.ap.hop_length]
                    wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                wav_padded.transpose_(1, 2)

            # collate attention alignments
            if batch[0]["attn"] is not None:
                attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]

@@ -397,6 +453,7 @@ class TTSDataset(Dataset):
                attns = torch.FloatTensor(attns).unsqueeze(1)
            else:
                attns = None
            # TODO: return dictionary
            return (
                text,
                text_lenghts,

@@ -409,6 +466,8 @@ class TTSDataset(Dataset):
                d_vectors,
                speaker_ids,
                attns,
                wav_padded,
                raw_text,
            )

        raise TypeError(
@@ -28,6 +28,31 @@ class LayerNorm(nn.Module):
        return x


class LayerNorm2(nn.Module):
    """Layer norm for the 2nd dimension of the input using the torch primitive.

    Args:
        channels (int): number of channels (2nd dimension) of the input.
        eps (float): small constant added to prevent division by zero.

    Shapes:
        - input: (B, C, T)
        - output: (B, C, T)
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class TemporalBatchNorm1d(nn.BatchNorm1d):
    """Normalize each channel separately over time and batch."""
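`LayerNorm2` normalizes over the channel axis of a `(B, C, T)` tensor by deferring to `F.layer_norm` after a transpose. A quick shape check, assuming the class as defined above:

    import torch

    norm = LayerNorm2(channels=192)
    x = torch.randn(8, 192, 50)               # (B, C, T)
    y = norm(x)
    print(y.shape)                             # torch.Size([8, 192, 50])
    print(y.mean(dim=1).abs().max() < 1e-5)    # per-(b, t) channel mean ~0 at init (gamma=1, beta=0)
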
@@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
        dropout_p (float): Dropout rate used after each conv layer.
    """

    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
        super().__init__()
        # class arguments
        self.in_channels = in_channels

@@ -33,13 +33,18 @@ class DurationPredictor(nn.Module):
        self.norm_2 = LayerNorm(hidden_channels)
        # output layer
        self.proj = nn.Conv1d(hidden_channels, 1, 1)
        if cond_channels is not None and cond_channels != 0:
            self.cond = nn.Conv1d(cond_channels, in_channels, 1)

    def forward(self, x, x_mask):
    def forward(self, x, x_mask, g=None):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C, 1]`
        """
        if g is not None:
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
@@ -16,7 +16,7 @@ class ResidualConv1dLayerNormBlock(nn.Module):
    ::

        x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
          |---------------> conv1d_1x1 -----------------------|
          |---------------> conv1d_1x1 ------------------|

    Args:
        in_channels (int): number of input tensor channels.
@@ -4,7 +4,7 @@ import torch
from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.glow_tts.glow import LayerNorm
from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2


class RelativePositionMultiHeadAttention(nn.Module):

@@ -271,7 +271,7 @@ class FeedForwardNetwork(nn.Module):
        dropout_p (float, optional): dropout rate. Defaults to 0.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0):
    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False):

        super().__init__()
        self.in_channels = in_channels

@@ -280,17 +280,46 @@ class FeedForwardNetwork(nn.Module):
        self.kernel_size = kernel_size
        self.dropout_p = dropout_p

        self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2)
        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size)
        self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = self.conv_1(self.padding(x * x_mask))
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.conv_2(x * x_mask)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, self._pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, self._pad_shape(padding))
        return x

    @staticmethod
    def _pad_shape(padding):
        l = padding[::-1]
        pad_shape = [item for sublist in l for item in sublist]
        return pad_shape
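`F.pad` expects a flat pad list ordered from the last dimension backwards, which is why `_pad_shape` reverses the nested `[dim][left, right]` list before flattening. A worked example for a causal pad of 2 on the time axis of a `[B, C, T]` tensor:

    padding = [[0, 0], [0, 0], [2, 0]]   # no pad on B and C, pad_l=2 on T
    reversed_dims = padding[::-1]        # [[2, 0], [0, 0], [0, 0]]
    pad_shape = [item for sublist in reversed_dims for item in sublist]
    print(pad_shape)                     # [2, 0, 0, 0, 0, 0] -> what F.pad expects
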
class RelativePositionTransformer(nn.Module):
    """Transformer with Relative Positional Encoding.

@@ -310,20 +339,23 @@ class RelativePositionTransformer(nn.Module):
            If default, relative encoding is disabled and it is a regular transformer.
            Defaults to None.
        input_length (int, optional): input length to limit position encoding. Defaults to None.
        layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses the torch layer_norm
            primitive. Use type "2"; type "1" is for backward compatibility. Defaults to "1".
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        hidden_channels_ffn,
        num_heads,
        num_layers,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        hidden_channels_ffn: int,
        num_heads: int,
        num_layers: int,
        kernel_size=1,
        dropout_p=0.0,
        rel_attn_window_size=None,
        input_length=None,
        rel_attn_window_size: int = None,
        input_length: int = None,
        layer_norm_type: str = "1",
    ):
        super().__init__()
        self.hidden_channels = hidden_channels

@@ -351,7 +383,12 @@ class RelativePositionTransformer(nn.Module):
                    input_length=input_length,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            if layer_norm_type == "1":
                self.norm_layers_1.append(LayerNorm(hidden_channels))
            elif layer_norm_type == "2":
                self.norm_layers_1.append(LayerNorm2(hidden_channels))
            else:
                raise ValueError(" [!] Unknown layer norm type")

            if hidden_channels != out_channels and (idx + 1) == self.num_layers:
                self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

@@ -366,7 +403,12 @@ class RelativePositionTransformer(nn.Module):
                )
            )

            self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
            if layer_norm_type == "1":
                self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
            elif layer_norm_type == "2":
                self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
            else:
                raise ValueError(" [!] Unknown layer norm type")

    def forward(self, x, x_mask):
        """
@@ -2,11 +2,13 @@ import math

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import functional

from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.ssim import ssim
from TTS.utils.audio import TorchSTFT


# pylint: disable=abstract-method

@@ -514,3 +516,142 @@ class AlignTTSLoss(nn.Module):
            + self.mdn_alpha * mdn_loss
        )
        return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}


class VitsGeneratorLoss(nn.Module):
    def __init__(self, c: Coqpit):
        super().__init__()
        self.kl_loss_alpha = c.kl_loss_alpha
        self.gen_loss_alpha = c.gen_loss_alpha
        self.feat_loss_alpha = c.feat_loss_alpha
        self.mel_loss_alpha = c.mel_loss_alpha
        self.stft = TorchSTFT(
            c.audio.fft_size,
            c.audio.hop_length,
            c.audio.win_length,
            sample_rate=c.audio.sample_rate,
            mel_fmin=c.audio.mel_fmin,
            mel_fmax=c.audio.mel_fmax,
            n_mels=c.audio.num_mels,
            use_mel=True,
            do_amp_to_db=True,
        )

    @staticmethod
    def feature_loss(feats_real, feats_generated):
        loss = 0
        for dr, dg in zip(feats_real, feats_generated):
            for rl, gl in zip(dr, dg):
                rl = rl.float().detach()
                gl = gl.float()
                loss += torch.mean(torch.abs(rl - gl))

        return loss * 2

    @staticmethod
    def generator_loss(scores_fake):
        loss = 0
        gen_losses = []
        for dg in scores_fake:
            dg = dg.float()
            l = torch.mean((1 - dg) ** 2)
            gen_losses.append(l)
            loss += l

        return loss, gen_losses

    @staticmethod
    def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
        """
        z_p, logs_q: [b, h, t_t]
        m_p, logs_p: [b, h, t_t]
        """
        z_p = z_p.float()
        logs_q = logs_q.float()
        m_p = m_p.float()
        logs_p = logs_p.float()
        z_mask = z_mask.float()

        kl = logs_p - logs_q - 0.5
        kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
        kl = torch.sum(kl * z_mask)
        l = kl / torch.sum(z_mask)
        return l
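In math terms, `kl_loss` is a single-sample Monte Carlo estimate of the KL divergence between the posterior q and the prior p, both diagonal Gaussians, evaluated at the posterior sample z_p and averaged over the unmasked positions:

    \mathrm{KL}(q \,\|\, p) \approx \log\sigma_p - \log\sigma_q - \tfrac{1}{2}
        + \frac{(z_p - \mu_p)^2}{2\sigma_p^2}, \qquad z_p \sim q, \quad \sigma_{\{p,q\}} = e^{\mathrm{logs}_{\{p,q\}}}

Summing the per-element terms against the mask and dividing by the mask sum gives the scalar the method returns.
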
    def forward(
        self,
        waveform,
        waveform_hat,
        z_p,
        logs_q,
        m_p,
        logs_p,
        z_len,
        scores_disc_fake,
        feats_disc_fake,
        feats_disc_real,
    ):
        """
        Shapes:
            - waveform: :math:`[B, 1, T]`
            - waveform_hat: :math:`[B, 1, T]`
            - z_p: :math:`[B, C, T]`
            - logs_q: :math:`[B, C, T]`
            - m_p: :math:`[B, C, T]`
            - logs_p: :math:`[B, C, T]`
            - z_len: :math:`[B]`
            - scores_disc_fake[i]: :math:`[B, C]`
            - feats_disc_fake[i][j]: :math:`[B, C, T', P]`
            - feats_disc_real[i][j]: :math:`[B, C, T', P]`
        """
        loss = 0.0
        return_dict = {}
        z_mask = sequence_mask(z_len).float()
        # compute mel spectrograms from the waveforms
        mel = self.stft(waveform)
        mel_hat = self.stft(waveform_hat)
        # compute losses
        loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
        loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
        loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
        loss = loss_kl + loss_feat + loss_mel + loss_gen
        # pass losses to the dict
        return_dict["loss_gen"] = loss_gen
        return_dict["loss_kl"] = loss_kl
        return_dict["loss_feat"] = loss_feat
        return_dict["loss_mel"] = loss_mel
        return_dict["loss"] = loss
        return return_dict
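With the default weights from `VitsConfig`, the total generator objective assembled above is

    L_G = \alpha_{kl} L_{kl} + \alpha_{feat} L_{feat}
        + \alpha_{mel} \lVert \mathrm{mel} - \widehat{\mathrm{mel}} \rVert_1
        + \alpha_{gen} \sum_i \mathbb{E}\big[(1 - D_i(\hat{y}))^2\big]

with alpha_kl = alpha_feat = alpha_gen = 1.0 and alpha_mel = 45.0, so the mel term dominates early in training.
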
class VitsDiscriminatorLoss(nn.Module):
    def __init__(self, c: Coqpit):
        super().__init__()
        self.disc_loss_alpha = c.disc_loss_alpha

    @staticmethod
    def discriminator_loss(scores_real, scores_fake):
        loss = 0
        real_losses = []
        fake_losses = []
        for dr, dg in zip(scores_real, scores_fake):
            dr = dr.float()
            dg = dg.float()
            real_loss = torch.mean((1 - dr) ** 2)
            fake_loss = torch.mean(dg ** 2)
            loss += real_loss + fake_loss
            real_losses.append(real_loss.item())
            fake_losses.append(fake_loss.item())

        return loss, real_losses, fake_losses

    def forward(self, scores_disc_real, scores_disc_fake):
        loss = 0.0
        return_dict = {}
        loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
        return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
        loss = loss + loss_disc
        return_dict["loss_disc"] = loss_disc
        return_dict["loss"] = loss
        return return_dict
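Both loss modules implement the least-squares GAN (LSGAN) objective rather than the original cross-entropy formulation. Per discriminator D_i the discriminator side is

    L_D = \sum_i \mathbb{E}_{y}\big[(1 - D_i(y))^2\big] + \mathbb{E}_{\hat{y}}\big[D_i(\hat{y})^2\big]

which matches the `(1 - dr) ** 2` and `dg ** 2` terms above; the generator side uses (1 - D_i(y_hat))^2 as shown earlier.
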
@@ -8,10 +8,10 @@ class GST(nn.Module):

    See https://arxiv.org/pdf/1803.09017"""

    def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
    def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None):
        super().__init__()
        self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
        self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim)
        self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim)

    def forward(self, inputs, speaker_embedding=None):
        enc_out = self.encoder(inputs)

@@ -83,19 +83,19 @@ class ReferenceEncoder(nn.Module):
class StyleTokenLayer(nn.Module):
    """NN Module attending to style tokens based on prosody encodings."""

    def __init__(self, num_heads, num_style_tokens, embedding_dim, d_vector_dim=None):
    def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
        super().__init__()

        self.query_dim = embedding_dim // 2
        self.query_dim = gst_embedding_dim // 2

        if d_vector_dim:
            self.query_dim += d_vector_dim

        self.key_dim = embedding_dim // num_heads
        self.key_dim = gst_embedding_dim // num_heads
        self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
        nn.init.normal_(self.style_tokens, mean=0, std=0.5)
        self.attention = MultiHeadAttention(
            query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads
            query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads
        )

    def forward(self, inputs):
@@ -0,0 +1,77 @@
import torch
from torch import nn
from torch.nn.modules.conv import Conv1d

from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator


class DiscriminatorS(torch.nn.Module):
    """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.

    Args:
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            Tensor: discriminator scores.
            List[Tensor]: list of features from the convolutional layers.
        """
        feat = []
        for l in self.convs:
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, 0.1)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)
        return x, feat


class VitsDiscriminator(nn.Module):
    """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminators.

    ::
        waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
               |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^

    Args:
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
        self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of lists of features from the layers of each discriminator.
        """
        scores, feats = self.mpd(x)
        score_sd, feats_sd = self.sd(x)
        return scores + [score_sd], feats + [feats_sd]
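A minimal smoke test of the wrapper, assuming HiFiGAN's `MultiPeriodDiscriminator` uses its usual five periods (2, 3, 5, 7, 11), so the scale discriminator brings the total to six score tensors:

    import torch

    disc = VitsDiscriminator()
    wav = torch.randn(2, 1, 8192)   # (B, 1, T) dummy waveform
    scores, feats = disc(wav)
    print(len(scores), len(feats))  # expected: 6 6 (5 period + 1 scale)
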
@@ -0,0 +1,271 @@
import math

import torch
from torch import nn

from TTS.tts.layers.glow_tts.glow import WN
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.data import sequence_mask

LRELU_SLOPE = 0.1


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab: int,
        out_channels: int,
        hidden_channels: int,
        hidden_channels_ffn: int,
        num_heads: int,
        num_layers: int,
        kernel_size: int,
        dropout_p: float,
    ):
        """Text Encoder for VITS model.

        Args:
            n_vocab (int): Number of characters for the embedding layer.
            out_channels (int): Number of channels for the output.
            hidden_channels (int): Number of channels for the hidden layers.
            hidden_channels_ffn (int): Number of channels for the convolutional layers.
            num_heads (int): Number of attention heads for the Transformer layers.
            num_layers (int): Number of Transformer layers.
            kernel_size (int): Kernel size for the FFN layers in the Transformer network.
            dropout_p (float): Dropout rate for the Transformer layers.
        """
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        self.encoder = RelativePositionTransformer(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            hidden_channels=hidden_channels,
            hidden_channels_ffn=hidden_channels_ffn,
            num_heads=num_heads,
            num_layers=num_layers,
            kernel_size=kernel_size,
            dropout_p=dropout_p,
            layer_norm_type="2",
            rel_attn_window_size=4,
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        """
        Shapes:
            - x: :math:`[B, T]`
            - x_length: :math:`[B]`
        """
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_layers,
        dropout_p=0,
        cond_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.half_channels = channels // 2
        self.mean_only = mean_only
        # input layer
        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        # coupling layers
        self.enc = WN(
            hidden_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            num_layers,
            dropout_p=dropout_p,
            c_in_channels=cond_channels,
        )
        # output layer
        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C, 1]`
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            log_scale = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(log_scale) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(log_scale, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-log_scale) * x_mask
            x = torch.cat([x0, x1], 1)
            return x
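The two branches are the standard affine coupling transform, which is trivially invertible because x0 passes through unchanged:

    \text{forward:}\quad x_1' = m(x_0) + x_1 \odot e^{s(x_0)}, \qquad \log|\det J| = \sum_{c,t} s(x_0)
    \text{inverse:}\quad x_1 = (x_1' - m(x_0)) \odot e^{-s(x_0)}

With `mean_only=True` (as used by `ResidualCouplingBlocks` below) s is fixed to 0, so the log-determinant vanishes.
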
class ResidualCouplingBlocks(nn.Module):
    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        num_layers: int,
        num_flows=4,
        cond_channels=0,
    ):
        """Residual Coupling blocks for VITS flow layers.

        Args:
            channels (int): Number of input and output tensor channels.
            hidden_channels (int): Number of hidden network channels.
            kernel_size (int): Kernel size of the WaveNet layers.
            dilation_rate (int): Dilation rate of the WaveNet layers.
            num_layers (int): Number of the WaveNet layers.
            num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
            cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
        """
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_layers = num_layers
        self.num_flows = num_flows
        self.cond_channels = cond_channels

        self.flows = nn.ModuleList()
        for _ in range(num_flows):
            self.flows.append(
                ResidualCouplingBlock(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    num_layers,
                    cond_channels=cond_channels,
                    mean_only=True,
                )
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C, 1]`
        """
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
                x = torch.flip(x, [1])
        else:
            for flow in reversed(self.flows):
                x = torch.flip(x, [1])
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        num_layers: int,
        cond_channels=0,
    ):
        """Posterior Encoder of VITS model.

        ::
            x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z

        Args:
            in_channels (int): Number of input tensor channels.
            out_channels (int): Number of output tensor channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size of the WaveNet convolution layers.
            dilation_rate (int): Dilation rate of the WaveNet layers.
            num_layers (int): Number of the WaveNet layers.
            cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_layers = num_layers
        self.cond_channels = cond_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_lengths: :math:`[B, 1]`
            - g: :math:`[B, C, 1]`
        """
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        mean, log_scale = torch.split(stats, self.out_channels, dim=1)
        z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
        return z, mean, log_scale, x_mask
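The sampling line in `PosteriorEncoder.forward` is the usual reparameterization trick, keeping the draw differentiable with respect to the predicted statistics:

    z = (\mu + \sigma \odot \epsilon) \odot \mathrm{mask}, \qquad \epsilon \sim \mathcal{N}(0, I), \quad \sigma = e^{\mathrm{log\_scale}}
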
@ -0,0 +1,276 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.generic.normalization import LayerNorm2
|
||||
from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform
|
||||
|
||||
|
||||
class DilatedDepthSeparableConv(nn.Module):
|
||||
def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor:
|
||||
"""Dilated Depth-wise Separable Convolution module.
|
||||
|
||||
::
|
||||
x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o
|
||||
|-------------------------------------------------------------------------------------^
|
||||
|
||||
Args:
|
||||
channels ([type]): [description]
|
||||
kernel_size ([type]): [description]
|
||||
num_layers ([type]): [description]
|
||||
dropout_p (float, optional): [description]. Defaults to 0.0.
|
||||
|
||||
Returns:
|
||||
torch.tensor: Network output masked by the input sequence mask.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
|
||||
self.convs_sep = nn.ModuleList()
|
||||
self.convs_1x1 = nn.ModuleList()
|
||||
self.norms_1 = nn.ModuleList()
|
||||
self.norms_2 = nn.ModuleList()
|
||||
for i in range(num_layers):
|
||||
dilation = kernel_size ** i
|
||||
padding = (kernel_size * dilation - dilation) // 2
|
||||
self.convs_sep.append(
|
||||
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
|
||||
)
|
||||
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
||||
self.norms_1.append(LayerNorm2(channels))
|
||||
self.norms_2.append(LayerNorm2(channels))
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_mask: :math:`[B, 1, T]`
|
||||
"""
|
||||
if g is not None:
|
||||
x = x + g
|
||||
for i in range(self.num_layers):
|
||||
y = self.convs_sep[i](x * x_mask)
|
||||
y = self.norms_1[i](y)
|
||||
y = F.gelu(y)
|
||||
y = self.convs_1x1[i](y)
|
||||
y = self.norms_2[i](y)
|
||||
y = F.gelu(y)
|
||||
y = self.dropout(y)
|
||||
x = x + y
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class ElementwiseAffine(nn.Module):
|
||||
"""Element-wise affine transform like no-population stats BatchNorm alternative.
|
||||
|
||||
Args:
|
||||
channels (int): Number of input tensor channels.
|
||||
"""
|
||||
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.translation = nn.Parameter(torch.zeros(channels, 1))
|
||||
self.log_scale = nn.Parameter(torch.zeros(channels, 1))
|
||||
|
||||
def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument
|
||||
if not reverse:
|
||||
y = (x * torch.exp(self.log_scale) + self.translation) * x_mask
|
||||
logdet = torch.sum(self.log_scale * x_mask, [1, 2])
|
||||
return y, logdet
|
||||
x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class ConvFlow(nn.Module):
|
||||
"""Dilated depth separable convolutional based spline flow.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input tensor channels.
|
||||
hidden_channels (int): Number of in network channels.
|
||||
kernel_size (int): Convolutional kernel size.
|
||||
num_layers (int): Number of convolutional layers.
|
||||
num_bins (int, optional): Number of spline bins. Defaults to 10.
|
||||
tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_channels: int,
|
||||
kernel_size: int,
|
||||
num_layers: int,
|
||||
num_bins=10,
|
||||
tail_bound=5.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_bins = num_bins
|
||||
self.tail_bound = tail_bound
|
||||
self.hidden_channels = hidden_channels
|
||||
self.half_channels = in_channels // 2
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0)
|
||||
self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||
h = self.pre(x0)
|
||||
h = self.convs(h, x_mask, g=g)
|
||||
h = self.proj(h) * x_mask
|
||||
|
||||
b, c, t = x0.shape
|
||||
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
||||
|
||||
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels)
|
||||
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels)
|
||||
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
||||
|
||||
x1, logabsdet = piecewise_rational_quadratic_transform(
|
||||
x1,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=reverse,
|
||||
tails="linear",
|
||||
tail_bound=self.tail_bound,
|
||||
)
|
||||
|
||||
x = torch.cat([x0, x1], 1) * x_mask
|
||||
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
||||
if not reverse:
|
||||
return x, logdet
|
||||
return x
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
"""Stochastic duration predictor with Spline Flows.
|
||||
|
||||
It applies Variational Dequantization and Variationsl Data Augmentation.
|
||||
|
||||
Paper:
|
||||
SDP: https://arxiv.org/pdf/2106.06103.pdf
|
||||
Spline Flow: https://arxiv.org/abs/1906.04032
|
||||
|
||||
::
|
||||
## Inference
|
||||
|
||||
x -> TextCondEncoder() -> Flow() -> dr_hat
|
||||
noise ----------------------^
|
||||
|
||||
## Training
|
||||
|---------------------|
|
||||
x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise
|
||||
d -> DurCondEncoder() -> ^ |
|
||||
|------------------------------------------------------------------------------|
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input tensor channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size of convolutional layers.
|
||||
dropout_p (float): Dropout rate.
|
||||
num_flows (int, optional): Number of flow blocks. Defaults to 4.
|
||||
cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# condition encoder text
|
||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
|
||||
self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||
|
||||
# posterior encoder
|
||||
self.flows = nn.ModuleList()
|
||||
self.flows.append(ElementwiseAffine(2))
|
||||
self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
|
||||
|
||||
# condition encoder duration
|
||||
self.post_pre = nn.Conv1d(1, hidden_channels, 1)
|
||||
self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
|
||||
self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||
|
||||
# flow layers
|
||||
self.post_flows = nn.ModuleList()
|
||||
self.post_flows.append(ElementwiseAffine(2))
|
||||
self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
|
||||
|
||||
if cond_channels != 0 and cond_channels is not None:
|
||||
self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
|
||||
|
||||
def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_mask: :math:`[B, 1, T]`
|
||||
- dr: :math:`[B, 1, T]`
|
||||
- g: :math:`[B, C]`
|
||||
"""
|
||||
# condition encoder text
|
||||
x = self.pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
x = self.convs(x, x_mask)
|
||||
x = self.proj(x) * x_mask
|
||||
|
||||
if not reverse:
|
||||
flows = self.flows
|
||||
assert dr is not None
|
||||
|
||||
# condition encoder duration
|
||||
h = self.post_pre(dr)
|
||||
h = self.post_convs(h, x_mask)
|
||||
h = self.post_proj(h) * x_mask
|
||||
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
z_q = noise
|
||||
|
||||
# posterior encoder
|
||||
logdet_tot_q = 0.0
|
||||
for idx, flow in enumerate(self.post_flows):
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h))
|
||||
logdet_tot_q = logdet_tot_q + logdet_q
|
||||
if idx > 0:
|
||||
z_q = torch.flip(z_q, [1])
|
||||
|
||||
z_u, z_v = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (dr - u) * x_mask
|
||||
|
||||
# posterior encoder - neg log likelihood
|
||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
|
||||
nll_posterior_encoder = (
|
||||
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
|
||||
)
|
||||
|
||||
z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
|
||||
logdet_tot = torch.sum(-z0, [1, 2])
|
||||
z = torch.cat([z0, z_v], 1)
|
||||
|
||||
# flow layers
|
||||
for idx, flow in enumerate(flows):
|
||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
if idx > 0:
|
||||
z = torch.flip(z, [1])
|
||||
|
||||
# flow layers - neg log likelihood
|
||||
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
|
||||
return nll_flow_layers + nll_posterior_encoder
|
||||
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale  # sample the Gaussian prior
|
||||
for flow in flows:
|
||||
z = torch.flip(z, [1])
|
||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||
|
||||
z0, _ = torch.split(z, [1, 1], 1)
|
||||
logw = z0
|
||||
return logw
|
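A minimal smoke test of the duration predictor above, with hypothetical shapes (`sdp`, `x`, `x_mask`, and `dr` are illustrative names)::

    import torch

    sdp = StochasticDurationPredictor(in_channels=192, hidden_channels=192, kernel_size=3, dropout_p=0.5)
    x = torch.randn(2, 192, 50)      # text-encoder features [B, C, T]
    x_mask = torch.ones(2, 1, 50)    # all positions valid
    dr = torch.rand(2, 1, 50) + 1.0  # positive per-token durations
    nll = sdp(x, x_mask, dr=dr)                            # training: per-sample negative log-likelihood, [B]
    logw = sdp(x, x_mask, reverse=True, noise_scale=1.0)   # inference: predicted log-durations, [B, 1, T]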
|
@ -0,0 +1,203 @@
|
|||
# adopted from https://github.com/bayesiains/nflows
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
||||
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
||||
DEFAULT_MIN_DERIVATIVE = 1e-3
|
||||
|
||||
|
||||
def piecewise_rational_quadratic_transform(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails=None,
|
||||
tail_bound=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||
):
|
||||
|
||||
if tails is None:
|
||||
spline_fn = rational_quadratic_spline
|
||||
spline_kwargs = {}
|
||||
else:
|
||||
spline_fn = unconstrained_rational_quadratic_spline
|
||||
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
||||
|
||||
outputs, logabsdet = spline_fn(
|
||||
inputs=inputs,
|
||||
unnormalized_widths=unnormalized_widths,
|
||||
unnormalized_heights=unnormalized_heights,
|
||||
unnormalized_derivatives=unnormalized_derivatives,
|
||||
inverse=inverse,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative,
|
||||
**spline_kwargs,
|
||||
)
|
||||
return outputs, logabsdet
|
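A minimal call sketch for the transform, assuming 10 bins, linear tails, and one spline per element (all tensor names are illustrative)::

    import torch

    x = torch.rand(8) * 2 - 1  # values inside (-tail_bound, tail_bound)
    w = torch.randn(8, 10)     # unnormalized bin widths
    h = torch.randn(8, 10)     # unnormalized bin heights
    d = torch.randn(8, 9)      # unnormalized interior derivatives (num_bins - 1; linear tails pad the edges)
    y, logabsdet = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear")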
||||
|
||||
|
||||
def searchsorted(bin_locations, inputs, eps=1e-6):
|
||||
bin_locations[..., -1] += eps
|
||||
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
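`searchsorted` maps each input to the index of the bin it falls into, given the cumulative bin edges. A tiny illustrative check::

    edges = torch.tensor([0.0, 0.3, 0.7, 1.0])
    searchsorted(edges.clone(), torch.tensor([0.1, 0.5, 0.9]))  # -> tensor([0, 1, 2])

Note that it nudges the last edge in place, hence the `.clone()` in this sketch.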
||||
|
||||
|
||||
def unconstrained_rational_quadratic_spline(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails="linear",
|
||||
tail_bound=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||
):
|
||||
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
||||
outside_interval_mask = ~inside_interval_mask
|
||||
|
||||
outputs = torch.zeros_like(inputs)
|
||||
logabsdet = torch.zeros_like(inputs)
|
||||
|
||||
if tails == "linear":
|
||||
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
||||
constant = np.log(np.exp(1 - min_derivative) - 1)
|
||||
unnormalized_derivatives[..., 0] = constant
|
||||
unnormalized_derivatives[..., -1] = constant
|
||||
|
||||
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
||||
logabsdet[outside_interval_mask] = 0
|
||||
else:
|
||||
raise RuntimeError("{} tails are not implemented.".format(tails))
|
||||
|
||||
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
|
||||
inputs=inputs[inside_interval_mask],
|
||||
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
||||
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
||||
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
||||
inverse=inverse,
|
||||
left=-tail_bound,
|
||||
right=tail_bound,
|
||||
bottom=-tail_bound,
|
||||
top=tail_bound,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative,
|
||||
)
|
||||
|
||||
return outputs, logabsdet
|
||||
|
||||
|
||||
def rational_quadratic_spline(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
left=0.0,
|
||||
right=1.0,
|
||||
bottom=0.0,
|
||||
top=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||
):
|
||||
if torch.min(inputs) < left or torch.max(inputs) > right:
|
||||
raise ValueError("Input to a transform is not within its domain")
|
||||
|
||||
num_bins = unnormalized_widths.shape[-1]
|
||||
|
||||
if min_bin_width * num_bins > 1.0:
|
||||
raise ValueError("Minimal bin width too large for the number of bins")
|
||||
if min_bin_height * num_bins > 1.0:
|
||||
raise ValueError("Minimal bin height too large for the number of bins")
|
||||
|
||||
widths = F.softmax(unnormalized_widths, dim=-1)
|
||||
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
||||
cumwidths = torch.cumsum(widths, dim=-1)
|
||||
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
|
||||
cumwidths = (right - left) * cumwidths + left
|
||||
cumwidths[..., 0] = left
|
||||
cumwidths[..., -1] = right
|
||||
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
||||
|
||||
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
||||
|
||||
heights = F.softmax(unnormalized_heights, dim=-1)
|
||||
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
||||
cumheights = torch.cumsum(heights, dim=-1)
|
||||
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
|
||||
cumheights = (top - bottom) * cumheights + bottom
|
||||
cumheights[..., 0] = bottom
|
||||
cumheights[..., -1] = top
|
||||
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
||||
|
||||
if inverse:
|
||||
bin_idx = searchsorted(cumheights, inputs)[..., None]
|
||||
else:
|
||||
bin_idx = searchsorted(cumwidths, inputs)[..., None]
|
||||
|
||||
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
||||
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
||||
|
||||
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
||||
delta = heights / widths
|
||||
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
||||
|
||||
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
||||
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
||||
|
||||
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
||||
|
||||
if inverse:
|
||||
a = (inputs - input_cumheights) * (
|
||||
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||
) + input_heights * (input_delta - input_derivatives)
|
||||
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
||||
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||
)
|
||||
c = -input_delta * (inputs - input_cumheights)
|
||||
|
||||
discriminant = b.pow(2) - 4 * a * c
|
||||
assert (discriminant >= 0).all()
|
||||
|
||||
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
||||
outputs = root * input_bin_widths + input_cumwidths
|
||||
|
||||
theta_one_minus_theta = root * (1 - root)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
|
||||
)
|
||||
derivative_numerator = input_delta.pow(2) * (
|
||||
input_derivatives_plus_one * root.pow(2)
|
||||
+ 2 * input_delta * theta_one_minus_theta
|
||||
+ input_derivatives * (1 - root).pow(2)
|
||||
)
|
||||
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
||||
|
||||
return outputs, -logabsdet
|
||||
else:
|
||||
theta = (inputs - input_cumwidths) / input_bin_widths
|
||||
theta_one_minus_theta = theta * (1 - theta)
|
||||
|
||||
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
|
||||
)
|
||||
outputs = input_cumheights + numerator / denominator
|
||||
|
||||
derivative_numerator = input_delta.pow(2) * (
|
||||
input_derivatives_plus_one * theta.pow(2)
|
||||
+ 2 * input_delta * theta_one_minus_theta
|
||||
+ input_derivatives * (1 - theta).pow(2)
|
||||
)
|
||||
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
||||
|
||||
return outputs, logabsdet
|
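Since the spline is bijective on its domain, applying the transform with `inverse` flipped should recover the input and negate the log-determinant. A hypothetical round-trip check, reusing the tensors from the earlier sketch::

    y, ld = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear")
    x_rec, ld_inv = piecewise_rational_quadratic_transform(y, w, h, d, inverse=True, tails="linear")
    assert torch.allclose(x_rec, x, atol=1e-4)
    assert torch.allclose(ld_inv, -ld, atol=1e-4)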
|
@ -4,20 +4,23 @@ from TTS.utils.generic_utils import find_module
|
|||
|
||||
def setup_model(config):
|
||||
print(" > Using model: {}".format(config.model))
|
||||
|
||||
MyModel = find_module("TTS.tts.models", config.model.lower())
|
||||
# define set of characters used by the model
|
||||
if config.characters is not None:
|
||||
# set characters from config
|
||||
symbols, phonemes = make_symbols(**config.characters.to_dict()) # pylint: disable=redefined-outer-name
|
||||
if hasattr(MyModel, "make_symbols"):
|
||||
symbols = MyModel.make_symbols(config)
|
||||
else:
|
||||
symbols, phonemes = make_symbols(**config.characters)
|
||||
else:
|
||||
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
|
||||
|
||||
if config.use_phonemes:
|
||||
symbols = phonemes
|
||||
# use default characters and assign them to config
|
||||
config.characters = parse_symbols()
|
||||
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
|
||||
# consider special `blank` character if `add_blank` is set True
|
||||
num_chars = num_chars + getattr(config, "add_blank", False)
|
||||
num_chars = len(symbols) + getattr(config, "add_blank", False)
|
||||
config.num_chars = num_chars
|
||||
# compatibility fix
|
||||
if "model_params" in config:
|
||||
|
|
|
@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
|
|||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
|
|||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
||||
|
||||
|
@ -75,9 +76,6 @@ class BaseTacotron(BaseTTS):
|
|||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
||||
# init multi-speaker layers
|
||||
self.init_multispeaker(config)
|
||||
|
||||
@staticmethod
|
||||
def _format_aux_input(aux_input: Dict) -> Dict:
|
||||
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
||||
|
@ -113,7 +111,7 @@ class BaseTacotron(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if "r" in state:
|
||||
self.decoder.set_r(state["r"])
|
||||
|
@ -236,6 +234,7 @@ class BaseTacotron(BaseTTS):
|
|||
def compute_gst(self, inputs, style_input, speaker_embedding=None):
|
||||
"""Compute global style token"""
|
||||
if isinstance(style_input, dict):
|
||||
# multiply each style token with a weight
|
||||
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
|
||||
if speaker_embedding is not None:
|
||||
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
|
||||
|
@ -247,8 +246,10 @@ class BaseTacotron(BaseTTS):
|
|||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
||||
elif style_input is None:
|
||||
# ignore style token and return zero tensor
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
else:
|
||||
# compute style tokens
|
||||
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
|
||||
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
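`compute_gst` accepts three kinds of `style_input`. A hypothetical sketch, assuming a GST-enabled Tacotron-family `model` and illustrative `encoder_outputs` / `style_mel` tensors::

    # 1) weighted style tokens: {token index: weight}
    out = model.compute_gst(encoder_outputs, {0: 0.3, 2: -0.1})
    # 2) a reference spectrogram to extract the style from
    out = model.compute_gst(encoder_outputs, style_mel)
    # 3) None: the style contribution is a zero vector
    out = model.compute_gst(encoder_outputs, None)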
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
@ -48,10 +48,17 @@ class BaseTTS(BaseModel):
|
|||
return get_speaker_manager(config, restore_path, data, out_path)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
|
||||
`in_channels` size of the connected layers.
|
||||
|
||||
If you need a different behaviour, override this function for your model.
|
||||
This implementation yields 3 possible outcomes:
|
||||
|
||||
1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
|
||||
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
|
||||
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
|
||||
`config.d_vector_dim` or 512.
|
||||
|
||||
You can override this function for new models.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
|
@ -59,12 +66,24 @@ class BaseTTS(BaseModel):
|
|||
"""
|
||||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
|
||||
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
|
||||
if data is not None or self.speaker_manager.speaker_ids:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
else:
|
||||
self.num_speakers = (
|
||||
config.num_speakers
|
||||
if "num_speakers" in config and config.num_speakers != 0
|
||||
else self.speaker_manager.num_speakers
|
||||
)
|
||||
|
||||
# set ultimate speaker embedding size
|
||||
if config.use_speaker_embedding or config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||
)
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
|
@ -87,7 +106,7 @@ class BaseTTS(BaseModel):
|
|||
text_input = batch[0]
|
||||
text_lengths = batch[1]
|
||||
speaker_names = batch[2]
|
||||
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
|
||||
linear_input = batch[3]
|
||||
mel_input = batch[4]
|
||||
mel_lengths = batch[5]
|
||||
stop_targets = batch[6]
|
||||
|
@ -95,6 +114,7 @@ class BaseTTS(BaseModel):
|
|||
d_vectors = batch[8]
|
||||
speaker_ids = batch[9]
|
||||
attn_mask = batch[10]
|
||||
waveform = batch[11]
|
||||
max_text_length = torch.max(text_lengths.float())
|
||||
max_spec_length = torch.max(mel_lengths.float())
|
||||
|
||||
|
@ -140,6 +160,7 @@ class BaseTTS(BaseModel):
|
|||
"max_text_length": float(max_text_length),
|
||||
"max_spec_length": float(max_spec_length),
|
||||
"item_idx": item_idx,
|
||||
"waveform": waveform,
|
||||
}
|
||||
|
||||
def get_data_loader(
|
||||
|
@ -160,15 +181,22 @@ class BaseTTS(BaseModel):
|
|||
speaker_id_mapping = None
|
||||
d_vector_mapping = None
|
||||
|
||||
# setup custom symbols if needed
|
||||
custom_symbols = None
|
||||
if hasattr(self, "make_symbols"):
|
||||
custom_symbols = self.make_symbols(self.config)
|
||||
|
||||
# init dataloader
|
||||
dataset = TTSDataset(
|
||||
outputs_per_step=config.r if "r" in config else 1,
|
||||
text_cleaner=config.text_cleaner,
|
||||
compute_linear_spec=config.model.lower() == "tacotron",
|
||||
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
||||
meta_data=data_items,
|
||||
ap=ap,
|
||||
characters=config.characters,
|
||||
custom_symbols=custom_symbols,
|
||||
add_blank=config["add_blank"],
|
||||
return_wav=config.return_wav if "return_wav" in config else False,
|
||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||
min_seq_len=config.min_seq_len,
|
||||
max_seq_len=config.max_seq_len,
|
||||
|
@ -185,8 +213,21 @@ class BaseTTS(BaseModel):
|
|||
)
|
||||
|
||||
if config.use_phonemes and config.compute_input_seq_cache:
|
||||
# precompute phonemes to have a better estimate of sequence lengths.
|
||||
dataset.compute_input_seq(config.num_loader_workers)
|
||||
if hasattr(self, "eval_data_items") and is_eval:
|
||||
dataset.items = self.eval_data_items
|
||||
elif hasattr(self, "train_data_items") and not is_eval:
|
||||
dataset.items = self.train_data_items
|
||||
else:
|
||||
# precompute phonemes to have a better estimate of sequence lengths.
|
||||
dataset.compute_input_seq(config.num_loader_workers)
|
||||
|
||||
# TODO: find a more efficient solution
|
||||
# cheap hack - store items in the model state to avoid recomputing when reinitializing the dataset
|
||||
if is_eval:
|
||||
self.eval_data_items = dataset.items
|
||||
else:
|
||||
self.train_data_items = dataset.items
|
||||
|
||||
dataset.sort_items()
|
||||
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
@ -216,7 +257,7 @@ class BaseTTS(BaseModel):
|
|||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self.get_aux_input()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
wav, alignment, model_outputs, _ = synthesis(
|
||||
outputs_dict = synthesis(
|
||||
self,
|
||||
sen,
|
||||
self.config,
|
||||
|
@ -228,9 +269,12 @@ class BaseTTS(BaseModel):
|
|||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
).values()
|
||||
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
||||
)
|
||||
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(
|
||||
outputs_dict["outputs"]["alignments"], output_fig=False
|
||||
)
|
||||
return test_figures, test_audios
|
||||
|
|
|
@ -12,14 +12,19 @@ from TTS.tts.models.base_tts import BaseTTS
|
|||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
class GlowTTS(BaseTTS):
|
||||
"""Glow TTS models from https://arxiv.org/abs/2005.11129
|
||||
"""GlowTTS model.
|
||||
|
||||
Paper abstract:
|
||||
Paper::
|
||||
https://arxiv.org/abs/2005.11129
|
||||
|
||||
Paper abstract::
|
||||
Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate
|
||||
mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained
|
||||
without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS,
|
||||
|
@ -144,7 +149,6 @@ class GlowTTS(BaseTTS):
|
|||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# drop residual frames wrt num_squeeze and set y_lengths.
|
||||
|
@ -361,12 +365,49 @@ class GlowTTS(BaseTTS):
|
|||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_run(self, ap):
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||
"""
|
||||
print(" | > Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self.get_aux_input()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
outputs = synthesis(
|
||||
self,
|
||||
sen,
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
ap,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
)
|
||||
|
||||
test_audios["{}-audio".format(idx)] = outputs["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs["outputs"]["model_outputs"], ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
|
||||
return test_figures, test_audios
|
||||
|
||||
def preprocess(self, y, y_lengths, y_max_length, attn=None):
|
||||
if y_max_length is not None:
|
||||
y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
|
||||
|
@ -382,7 +423,7 @@ class GlowTTS(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
|
|||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -105,7 +106,7 @@ class SpeedySpeech(BaseTTS):
|
|||
if isinstance(config.model_args.length_scale, int)
|
||||
else config.model_args.length_scale
|
||||
)
|
||||
self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
|
||||
self.emb = nn.Embedding(self.num_chars, config.model_args.hidden_channels)
|
||||
self.encoder = Encoder(
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.hidden_channels,
|
||||
|
@ -227,6 +228,7 @@ class SpeedySpeech(BaseTTS):
|
|||
outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -306,7 +308,7 @@ class SpeedySpeech(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -23,19 +23,19 @@ class Tacotron(BaseTacotron):
|
|||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config)
|
||||
|
||||
self.num_chars, self.config = self.get_characters(config)
|
||||
chars, self.config = self.get_characters(config)
|
||||
config.num_chars = self.num_chars = len(chars)
|
||||
|
||||
# pass all config fields to `self`
|
||||
# for fewer code change
|
||||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
# speaker embedding layer
|
||||
if self.num_speakers > 1:
|
||||
# set speaker embedding channel size for determining `in_channels` for the connected layers.
|
||||
# `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
|
||||
# on the number of speakers inferred from the dataset.
|
||||
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||
self.init_multispeaker(config)
|
||||
|
||||
# speaker and gst embeddings are concatenated in the decoder input
|
||||
if self.num_speakers > 1:
|
||||
self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim
|
||||
|
||||
if self.use_gst:
|
||||
|
@ -75,13 +75,11 @@ class Tacotron(BaseTacotron):
|
|||
if self.gst and self.use_gst:
|
||||
self.gst_layer = GST(
|
||||
num_mel=self.decoder_output_dim,
|
||||
d_vector_dim=self.d_vector_dim
|
||||
if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
|
||||
else None,
|
||||
num_heads=self.gst.gst_num_heads,
|
||||
num_style_tokens=self.gst.gst_num_style_tokens,
|
||||
gst_embedding_dim=self.gst.gst_embedding_dim,
|
||||
)
|
||||
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
self._init_backward_decoder()
|
||||
|
@ -106,7 +104,9 @@ class Tacotron(BaseTacotron):
|
|||
self.max_decoder_steps,
|
||||
)
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
|
||||
def forward( # pylint: disable=dangerous-default-value
|
||||
self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
|
||||
):
|
||||
"""
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
|
@ -115,6 +115,7 @@ class Tacotron(BaseTacotron):
|
|||
mel_lengths: [B]
|
||||
aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
|
||||
"""
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||
inputs = self.embedding(text)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
|
@ -125,12 +126,10 @@ class Tacotron(BaseTacotron):
|
|||
# global style token
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
)
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
||||
# speaker embedding
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vectors:
|
||||
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||
if not self.use_d_vector_file:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
|
@ -182,7 +181,7 @@ class Tacotron(BaseTacotron):
|
|||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vectors:
|
||||
if not self.use_d_vector_file:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])
|
||||
# reshape embedded_speakers
|
||||
|
|
|
@ -23,7 +23,7 @@ class Tacotron2(BaseTacotron):
|
|||
super().__init__(config)
|
||||
|
||||
chars, self.config = self.get_characters(config)
|
||||
self.num_chars = len(chars)
|
||||
config.num_chars = len(chars)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
# pass all config fields to `self`
|
||||
|
@ -31,12 +31,11 @@ class Tacotron2(BaseTacotron):
|
|||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
# speaker embedding layer
|
||||
if self.num_speakers > 1:
|
||||
# set speaker embedding channel size for determining `in_channels` for the connected layers.
|
||||
# `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
|
||||
# on the number of speakers inferred from the dataset.
|
||||
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||
self.init_multispeaker(config)
|
||||
|
||||
# speaker and gst embeddings are concatenated in the decoder input
|
||||
if self.num_speakers > 1:
|
||||
self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim
|
||||
|
||||
if self.use_gst:
|
||||
|
@ -47,6 +46,7 @@ class Tacotron2(BaseTacotron):
|
|||
|
||||
# base model layers
|
||||
self.encoder = Encoder(self.encoder_in_features)
|
||||
|
||||
self.decoder = Decoder(
|
||||
self.decoder_in_features,
|
||||
self.decoder_output_dim,
|
||||
|
@ -73,9 +73,6 @@ class Tacotron2(BaseTacotron):
|
|||
if self.gst and self.use_gst:
|
||||
self.gst_layer = GST(
|
||||
num_mel=self.decoder_output_dim,
|
||||
d_vector_dim=self.d_vector_dim
|
||||
if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
|
||||
else None,
|
||||
num_heads=self.gst.gst_num_heads,
|
||||
num_style_tokens=self.gst.gst_num_style_tokens,
|
||||
gst_embedding_dim=self.gst.gst_embedding_dim,
|
||||
|
@ -110,7 +107,9 @@ class Tacotron2(BaseTacotron):
|
|||
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
||||
return mel_outputs, mel_outputs_postnet, alignments
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
|
||||
def forward( # pylint: disable=dangerous-default-value
|
||||
self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
|
||||
):
|
||||
"""
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
|
@ -130,11 +129,10 @@ class Tacotron2(BaseTacotron):
|
|||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
)
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vectors:
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
||||
|
||||
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||
if not self.use_d_vector_file:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
|
@ -186,8 +184,9 @@ class Tacotron2(BaseTacotron):
|
|||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vectors:
|
||||
if not self.use_d_vector_file:
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
|
||||
# reshape embedded_speakers
|
||||
if embedded_speakers.ndim == 1:
|
||||
|
|
|
@ -0,0 +1,767 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
|
||||
# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
||||
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
|
||||
"""Segment each sample in a batch based on the provided segment indices"""
|
||||
segments = torch.zeros_like(x[:, :, :segment_size])
|
||||
for i in range(x.size(0)):
|
||||
index_start = segment_indices[i]
|
||||
index_end = index_start + segment_size
|
||||
segments[i] = x[i, :, index_start:index_end]
|
||||
return segments
|
||||
|
||||
|
||||
def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
|
||||
"""Create random segments based on the input lengths."""
|
||||
B, _, T = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = T
|
||||
max_idxs = x_lengths - segment_size + 1
|
||||
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
|
||||
ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
|
||||
ret = segment(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
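A minimal sketch of how these helpers slice training features, with illustrative shapes::

    import torch

    z = torch.randn(2, 192, 100)  # [B, C, T] latent frames
    lens = torch.tensor([100, 60])
    seg, ids = rand_segment(z, lens, segment_size=32)
    # seg: [2, 192, 32]; ids: start frame of each random slice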
||||
|
||||
|
||||
@dataclass
|
||||
class VitsArgs(Coqpit):
|
||||
"""VITS model arguments.
|
||||
|
||||
Args:
|
||||
|
||||
num_chars (int):
|
||||
Number of characters in the vocabulary. Defaults to 100.
|
||||
|
||||
out_channels (int):
|
||||
Number of output channels. Defaults to 513.
|
||||
|
||||
spec_segment_size (int):
|
||||
Decoder input segment size. Defaults to 32 `(32 * hop_length = waveform length)`.
|
||||
|
||||
hidden_channels (int):
|
||||
Number of hidden channels of the model. Defaults to 192.
|
||||
|
||||
hidden_channels_ffn_text_encoder (int):
|
||||
Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 768.
|
||||
|
||||
num_heads_text_encoder (int):
|
||||
Number of attention heads of the text encoder transformer. Defaults to 2.
|
||||
|
||||
num_layers_text_encoder (int):
|
||||
Number of transformer layers in the text encoder. Defaults to 6.
|
||||
|
||||
kernel_size_text_encoder (int):
|
||||
Kernel size of the text encoder transformer FFN layers. Defaults to 3.
|
||||
|
||||
dropout_p_text_encoder (float):
|
||||
Dropout rate of the text encoder. Defaults to 0.1.
|
||||
|
||||
dropout_p_duration_predictor (float):
|
||||
Dropout rate of the duration predictor. Defaults to 0.1.
|
||||
|
||||
kernel_size_posterior_encoder (int):
|
||||
Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.
|
||||
|
||||
dilation_rate_posterior_encoder (int):
|
||||
Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.
|
||||
|
||||
num_layers_posterior_encoder (int):
|
||||
Number of posterior encoder's WaveNet layers. Defaults to 16.
|
||||
|
||||
kernel_size_flow (int):
|
||||
Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
|
||||
|
||||
dilation_rate_flow (int):
|
||||
Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
|
||||
|
||||
num_layers_flow (int):
|
||||
Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.
|
||||
|
||||
resblock_type_decoder (str):
|
||||
Type of the residual block in the decoder network. Defaults to "1".
|
||||
|
||||
resblock_kernel_sizes_decoder (List[int]):
|
||||
Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
|
||||
|
||||
resblock_dilation_sizes_decoder (List[List[int]]):
|
||||
Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
|
||||
|
||||
upsample_rates_decoder (List[int]):
|
||||
Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these
|
||||
values must be equal to the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
|
||||
|
||||
upsample_initial_channel_decoder (int):
|
||||
Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
|
||||
|
||||
upsample_kernel_sizes_decoder (List[int]):
|
||||
Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
|
||||
|
||||
use_sdp (bool):
|
||||
Use Stochastic Duration Predictor. Defaults to True.
|
||||
|
||||
noise_scale (float):
|
||||
Noise scale used for the sample noise tensor in training. Defaults to 1.0.
|
||||
|
||||
inference_noise_scale (float):
|
||||
Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
|
||||
|
||||
length_scale (int):
|
||||
Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.
|
||||
|
||||
noise_scale_dp (float):
|
||||
Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
|
||||
|
||||
inference_noise_scale_dp (float):
|
||||
Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8.
|
||||
|
||||
max_inference_len (int):
|
||||
Maximum inference length to limit the memory use. Defaults to None.
|
||||
|
||||
init_discriminator (bool):
|
||||
Initialize the discriminator network if set to True. Set to False for inference. Defaults to True.
|
||||
|
||||
use_spectral_norm_disriminator (bool):
|
||||
Use spectral normalization over weight norm in the discriminator. Defaults to False.
|
||||
|
||||
use_speaker_embedding (bool):
|
||||
Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
|
||||
|
||||
num_speakers (int):
|
||||
Number of speakers for the speaker embedding layer. Defaults to 0.
|
||||
|
||||
speakers_file (str):
|
||||
Path to the speaker mapping file for the Speaker Manager. Defaults to None.
|
||||
|
||||
speaker_embedding_channels (int):
|
||||
Number of speaker embedding channels. Defaults to 256.
|
||||
|
||||
use_d_vector_file (bool):
|
||||
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
|
||||
|
||||
d_vector_dim (int):
|
||||
Number of d-vector channels. Defaults to 0.
|
||||
|
||||
detach_dp_input (bool):
|
||||
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
|
||||
"""
|
||||
|
||||
num_chars: int = 100
|
||||
out_channels: int = 513
|
||||
spec_segment_size: int = 32
|
||||
hidden_channels: int = 192
|
||||
hidden_channels_ffn_text_encoder: int = 768
|
||||
num_heads_text_encoder: int = 2
|
||||
num_layers_text_encoder: int = 6
|
||||
kernel_size_text_encoder: int = 3
|
||||
dropout_p_text_encoder: float = 0.1
|
||||
dropout_p_duration_predictor: float = 0.1
|
||||
kernel_size_posterior_encoder: int = 5
|
||||
dilation_rate_posterior_encoder: int = 1
|
||||
num_layers_posterior_encoder: int = 16
|
||||
kernel_size_flow: int = 5
|
||||
dilation_rate_flow: int = 1
|
||||
num_layers_flow: int = 4
|
||||
resblock_type_decoder: int = "1"
|
||||
resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
|
||||
resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
|
||||
upsample_initial_channel_decoder: int = 512
|
||||
upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
|
||||
use_sdp: bool = True
|
||||
noise_scale: float = 1.0
|
||||
inference_noise_scale: float = 0.667
|
||||
length_scale: int = 1
|
||||
noise_scale_dp: float = 1.0
|
||||
inference_noise_scale_dp: float = 0.8
|
||||
max_inference_len: int = None
|
||||
init_discriminator: bool = True
|
||||
use_spectral_norm_disriminator: bool = False
|
||||
use_speaker_embedding: bool = False
|
||||
num_speakers: int = 0
|
||||
speakers_file: str = None
|
||||
speaker_embedding_channels: int = 256
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_dim: int = 0
|
||||
detach_dp_input: bool = True
|
||||
|
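Since `Vits` also accepts a plain `VitsArgs` instance (see `__init__` below), a quick hypothetical construction for shape checks::

    args = VitsArgs(num_chars=120, spec_segment_size=32)
    model = Vits(args)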
||||
|
||||
class Vits(BaseTTS):
|
||||
"""VITS TTS model
|
||||
|
||||
Paper::
|
||||
https://arxiv.org/pdf/2106.06103.pdf
|
||||
|
||||
Paper Abstract::
|
||||
Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel
|
||||
sampling have been proposed, but their sample quality does not match that of two-stage TTS systems.
|
||||
In this work, we present a parallel end-to-end TTS method that generates more natural sounding audio than
|
||||
current two-stage models. Our method adopts variational inference augmented with normalizing flows and
|
||||
an adversarial training process, which improves the expressive power of generative modeling. We also propose a
|
||||
stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the
|
||||
uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the
|
||||
natural one-to-many relationship in which a text input can be spoken in multiple ways
|
||||
with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS)
|
||||
on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly
|
||||
available TTS systems and achieves a MOS comparable to ground truth.
|
||||
|
||||
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
|
||||
|
||||
Examples:
|
||||
>>> from TTS.tts.configs import VitsConfig
|
||||
>>> from TTS.tts.models.vits import Vits
|
||||
>>> config = VitsConfig()
|
||||
>>> model = Vits(config)
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(self, config: Coqpit):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.END2END = True
|
||||
|
||||
if config.__class__.__name__ == "VitsConfig":
|
||||
# loading from VitsConfig
|
||||
if "num_chars" not in config:
|
||||
_, self.config, num_chars = self.get_characters(config)
|
||||
config.model_args.num_chars = num_chars
|
||||
else:
|
||||
self.config = config
|
||||
config.model_args.num_chars = config.num_chars
|
||||
args = self.config.model_args
|
||||
elif isinstance(config, VitsArgs):
|
||||
# loading from VitsArgs
|
||||
self.config = config
|
||||
args = config
|
||||
else:
|
||||
raise ValueError("config must be either a VitsConfig or VitsArgs")
|
||||
|
||||
self.args = args
|
||||
|
||||
self.init_multispeaker(config)
|
||||
|
||||
self.length_scale = args.length_scale
|
||||
self.noise_scale = args.noise_scale
|
||||
self.inference_noise_scale = args.inference_noise_scale
|
||||
self.inference_noise_scale_dp = args.inference_noise_scale_dp
|
||||
self.noise_scale_dp = args.noise_scale_dp
|
||||
self.max_inference_len = args.max_inference_len
|
||||
self.spec_segment_size = args.spec_segment_size
|
||||
|
||||
self.text_encoder = TextEncoder(
|
||||
args.num_chars,
|
||||
args.hidden_channels,
|
||||
args.hidden_channels,
|
||||
args.hidden_channels_ffn_text_encoder,
|
||||
args.num_heads_text_encoder,
|
||||
args.num_layers_text_encoder,
|
||||
args.kernel_size_text_encoder,
|
||||
args.dropout_p_text_encoder,
|
||||
)
|
||||
|
||||
self.posterior_encoder = PosteriorEncoder(
|
||||
args.out_channels,
|
||||
args.hidden_channels,
|
||||
args.hidden_channels,
|
||||
kernel_size=args.kernel_size_posterior_encoder,
|
||||
dilation_rate=args.dilation_rate_posterior_encoder,
|
||||
num_layers=args.num_layers_posterior_encoder,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
)
|
||||
|
||||
self.flow = ResidualCouplingBlocks(
|
||||
args.hidden_channels,
|
||||
args.hidden_channels,
|
||||
kernel_size=args.kernel_size_flow,
|
||||
dilation_rate=args.dilation_rate_flow,
|
||||
num_layers=args.num_layers_flow,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
)
|
||||
|
||||
if args.use_sdp:
|
||||
self.duration_predictor = StochasticDurationPredictor(
|
||||
args.hidden_channels,
|
||||
192,
|
||||
3,
|
||||
args.dropout_p_duration_predictor,
|
||||
4,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
)
|
||||
else:
|
||||
self.duration_predictor = DurationPredictor(
|
||||
args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
|
||||
)
|
||||
|
||||
self.waveform_decoder = HifiganGenerator(
|
||||
args.hidden_channels,
|
||||
1,
|
||||
args.resblock_type_decoder,
|
||||
args.resblock_dilation_sizes_decoder,
|
||||
args.resblock_kernel_sizes_decoder,
|
||||
args.upsample_kernel_sizes_decoder,
|
||||
args.upsample_initial_channel_decoder,
|
||||
args.upsample_rates_decoder,
|
||||
inference_padding=0,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
conv_pre_weight_norm=False,
|
||||
conv_post_weight_norm=False,
|
||||
conv_post_bias=False,
|
||||
)
|
||||
|
||||
if args.init_discriminator:
|
||||
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
||||
If you need a different behaviour, override this function for your model.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
if hasattr(config, "model_args"):
|
||||
config = config.model_args
|
||||
self.embedded_speaker_dim = 0
|
||||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
|
||||
self.speaker_manager.num_speakers = config.num_speakers
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = config.speaker_embedding_channels
|
||||
self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
|
||||
# init d-vector usage
|
||||
if config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = config.d_vector_dim
|
||||
|
||||
@staticmethod
|
||||
def _set_cond_input(aux_input: Dict):
|
||||
"""Set the speaker conditioning input based on the multi-speaker mode."""
|
||||
sid, g = None, None
|
||||
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
|
||||
sid = aux_input["speaker_ids"]
|
||||
if sid.ndim == 0:
|
||||
sid = sid.unsqueeze_(0)
|
||||
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
|
||||
g = aux_input["d_vectors"]
|
||||
return sid, g
|
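A tiny illustrative call showing the 0-dim speaker ID fix-up (`_set_cond_input` is a static method)::

    sid, g = Vits._set_cond_input({"speaker_ids": torch.tensor(3), "d_vectors": None})
    # sid -> tensor([3]); g -> None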
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.tensor,
|
||||
x_lengths: torch.tensor,
|
||||
y: torch.tensor,
|
||||
y_lengths: torch.tensor,
|
||||
aux_input={"d_vectors": None, "speaker_ids": None},
|
||||
) -> Dict:
|
||||
"""Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x (torch.tensor): Batch of input character sequence IDs.
|
||||
x_lengths (torch.tensor): Batch of input character sequence lengths.
|
||||
y (torch.tensor): Batch of input spectrograms.
|
||||
y_lengths (torch.tensor): Batch of input spectrogram lengths.
|
||||
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
|
||||
|
||||
Returns:
|
||||
Dict: model outputs keyed by the output name.
|
||||
|
||||
Shapes:
|
||||
- x: :math:`[B, T_seq]`
|
||||
- x_lengths: :math:`[B]`
|
||||
- y: :math:`[B, C, T_spec]`
|
||||
- y_lengths: :math:`[B]`
|
||||
- d_vectors: :math:`[B, C, 1]`
|
||||
- speaker_ids: :math:`[B]`
|
||||
"""
|
||||
outputs = {}
|
||||
sid, g = self._set_cond_input(aux_input)
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
|
||||
|
||||
# speaker embedding
|
||||
if self.num_speakers > 1 and sid is not None:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# posterior encoder
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
||||
|
||||
# flow layers
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
|
||||
# find the alignment path
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
with torch.no_grad():
|
||||
o_scale = torch.exp(-2 * logs_p)
|
||||
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
|
||||
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
||||
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp = logp2 + logp3
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
|
||||
# duration predictor
|
||||
attn_durations = attn.sum(3)
|
||||
if self.args.use_sdp:
|
||||
nll_duration = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
attn_durations,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
)
|
||||
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
|
||||
outputs["nll_duration"] = nll_duration
|
||||
else:
|
||||
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
|
||||
log_durations = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
)
|
||||
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||
outputs["loss_duration"] = loss_duration
|
||||
|
||||
# expand prior
|
||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||
|
||||
# select a random feature segment for the waveform decoder
|
||||
z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
|
||||
o = self.waveform_decoder(z_slice, g=g)
|
||||
outputs.update(
|
||||
{
|
||||
"model_outputs": o,
|
||||
"alignments": attn.squeeze(1),
|
||||
"slice_ids": slice_ids,
|
||||
"z": z,
|
||||
"z_p": z_p,
|
||||
"m_p": m_p,
|
||||
"logs_p": logs_p,
|
||||
"m_q": m_q,
|
||||
"logs_q": logs_q,
|
||||
}
|
||||
)
|
||||
return outputs
|
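A hypothetical training-shape check for this forward pass, using the default `out_channels=513` and `spec_segment_size=32` (`model` is an illustrative single-speaker instance)::

    x = torch.randint(0, 100, (2, 40))  # token IDs [B, T_seq]
    x_lengths = torch.tensor([40, 35])
    y = torch.randn(2, 513, 120)        # linear spectrograms [B, C, T_spec]
    y_lengths = torch.tensor([120, 100])
    out = model.forward(x, x_lengths, y, y_lengths)
    # out["model_outputs"]: [2, 1, 32 * hop_length] decoded waveform slice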
||||
|
||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, T_seq]`
|
||||
- d_vectors: :math:`[B, C, 1]`
|
||||
- speaker_ids: :math:`[B]`
|
||||
"""
|
||||
sid, g = self._set_cond_input(aux_input)
|
||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
|
||||
|
||||
if self.num_speakers > 0 and sid is not None:
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
|
||||
if self.args.use_sdp:
|
||||
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
|
||||
else:
|
||||
logw = self.duration_predictor(x, x_mask, g=g)
|
||||
|
||||
w = torch.exp(logw) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
|
||||
|
||||
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
|
||||
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
|
||||
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
||||
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
|
||||
|
||||
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
|
||||
return outputs
|
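A hypothetical end-to-end inference sketch (`model` and `token_ids` are illustrative)::

    token_ids = torch.randint(0, 100, (1, 42))  # [B, T_seq]
    with torch.no_grad():
        out = model.inference(token_ids)
    wav = out["model_outputs"]  # [1, 1, T_wav]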
||||
|
||||
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
|
||||
"""TODO: create an end-point for voice conversion"""
|
||||
assert self.num_speakers > 0, "num_speakers has to be larger than 0."
|
||||
g_src = self.emb_g(sid_src).unsqueeze(-1)
|
||||
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
|
||||
z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)  # "enc_q" in the upstream VITS code
|
||||
z_p = self.flow(z, y_mask, g=g_src)
|
||||
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
||||
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
|
||||
return o_hat, y_mask, (z, z_p, z_hat)
|
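A hypothetical call converting a reference spectrogram from speaker 0 to speaker 1 (`spec` and `spec_lengths` are illustrative)::

    o_hat, y_mask, _ = model.voice_conversion(spec, spec_lengths, sid_src=torch.tensor([0]), sid_tgt=torch.tensor([1]))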
||||
|
||||
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
"""Perform a single training step. Run the model forward pass and compute losses.
|
||||
|
||||
Args:
|
||||
batch (Dict): Input tensors.
|
||||
criterion (nn.Module): Loss layer designed for the model.
|
||||
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Model outputs and computed losses.
|
||||
"""
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if optimizer_idx not in [0, 1]:
|
||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||
|
||||
if optimizer_idx == 0:
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
linear_input = batch["linear_input"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
waveform = batch["waveform"]
|
||||
|
||||
# generator pass
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
linear_input.transpose(1, 2),
|
||||
mel_lengths,
|
||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
|
||||
)
|
||||
|
||||
# cache tensors for the discriminator
|
||||
self.y_disc_cache = None
|
||||
self.wav_seg_disc_cache = None
|
||||
self.y_disc_cache = outputs["model_outputs"]
|
||||
wav_seg = segment(
|
||||
waveform.transpose(1, 2),
|
||||
outputs["slice_ids"] * self.config.audio.hop_length,
|
||||
self.args.spec_segment_size * self.config.audio.hop_length,
|
||||
)
|
||||
self.wav_seg_disc_cache = wav_seg
|
||||
outputs["waveform_seg"] = wav_seg
|
||||
|
||||
# compute discriminator scores and features
|
||||
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
|
||||
_, outputs["feats_disc_real"] = self.disc(wav_seg)
|
||||
|
||||
# compute losses
|
||||
with autocast(enabled=False): # use float32 for the criterion
|
||||
loss_dict = criterion[optimizer_idx](
|
||||
waveform_hat=outputs["model_outputs"].float(),
|
||||
waveform=wav_seg.float(),
|
||||
z_p=outputs["z_p"].float(),
|
||||
logs_q=outputs["logs_q"].float(),
|
||||
m_p=outputs["m_p"].float(),
|
||||
logs_p=outputs["logs_p"].float(),
|
||||
z_len=mel_lengths,
|
||||
scores_disc_fake=outputs["scores_disc_fake"],
|
||||
feats_disc_fake=outputs["feats_disc_fake"],
|
||||
feats_disc_real=outputs["feats_disc_real"],
|
||||
)
|
||||
|
||||
# handle the duration loss
|
||||
if self.args.use_sdp:
|
||||
loss_dict["nll_duration"] = outputs["nll_duration"]
|
||||
loss_dict["loss"] += outputs["nll_duration"]
|
||||
else:
|
||||
loss_dict["loss_duration"] = outputs["loss_duration"]
|
||||
loss_dict["loss"] += outputs["nll_duration"]
|
||||
|
||||
elif optimizer_idx == 1:
|
||||
# discriminator pass
|
||||
outputs = {}
|
||||
|
||||
# compute scores and features
|
||||
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
|
||||
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
|
||||
|
||||
# compute loss
|
||||
with autocast(enabled=False): # use float32 for the criterion
|
||||
loss_dict = criterion[optimizer_idx](
|
||||
outputs["scores_disc_real"],
|
||||
outputs["scores_disc_fake"],
|
||||
)
|
||||
return outputs, loss_dict
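Since `optimizer_idx` 0 updates the generator and 1 updates the discriminator against the tensors cached during the generator pass, the two calls must happen in that order within one batch. A simplified sketch of the driving loop (the real loop lives in the `Trainer`; names here are illustrative):

    # optimizers = [gen_optimizer, disc_optimizer]; criterions from model.get_criterion()
    for optimizer_idx, optimizer in enumerate(optimizers):
        optimizer.zero_grad()
        _, loss_dict = model.train_step(batch, criterions, optimizer_idx)
        loss_dict["loss"].backward()
        optimizer.step()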

    def train_log(
        self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
    ):  # pylint: disable=no-self-use
        """Create visualizations and waveform examples.

        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
        be projected onto Tensorboard.

        Args:
            ap (AudioProcessor): audio processor used at training.
            batch (Dict): Model inputs used at the previous training step.
            outputs (Dict): Model outputs generated at the previous training step.

        Returns:
            Tuple[Dict, np.ndarray]: training plots and output waveform.
        """
        y_hat = outputs[0]["model_outputs"]
        y = outputs[0]["waveform_seg"]
        figures = plot_results(y_hat, y, ap, name_prefix)
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {f"{name_prefix}/audio": sample_voice}

        alignments = outputs[0]["alignments"]
        align_img = alignments[0].data.cpu().numpy().T

        figures.update(
            {
                "alignment": plot_alignment(align_img, output_fig=False),
            }
        )
        return figures, audios

    @torch.no_grad()
    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
        return self.train_step(batch, criterion, optimizer_idx)

    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
        return self.train_log(ap, batch, outputs, "eval")

    @torch.no_grad()
    def test_run(self, ap) -> Tuple[Dict, Dict]:
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        aux_inputs = self.get_aux_input()
        for idx, sen in enumerate(test_sentences):
            wav, alignment, _, _ = synthesis(
                self,
                sen,
                self.config,
                "cuda" in str(next(self.parameters()).device),
                ap,
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()

            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
        return test_figures, test_audios

    def get_optimizer(self) -> List:
        """Initiate and return the GAN optimizers based on the config parameters.

        It returns 2 optimizers in a list. The first one is for the generator and the second one is for the discriminator.

        Returns:
            List: optimizers.
        """
        self.disc.requires_grad_(False)
        gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
        self.disc.requires_grad_(True)
        optimizer1 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
        )
        optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
        return [optimizer1, optimizer2]
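Temporarily freezing the discriminator is a compact way to collect only the generator parameters without enumerating submodules by name. The same idiom in isolation, as a sketch with a toy module:

    import torch.nn as nn

    class ToyGAN(nn.Module):
        def __init__(self):
            super().__init__()
            self.gen = nn.Linear(8, 8)   # stands in for the generator
            self.disc = nn.Linear(8, 1)  # stands in for the discriminator

    m = ToyGAN()
    m.disc.requires_grad_(False)
    gen_params = [p for p in m.parameters() if p.requires_grad]  # generator params only
    m.disc.requires_grad_(True)
    assert all(p.requires_grad for p in m.parameters())  # everything trainable again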

    def get_lr(self) -> List:
        """Set the initial learning rates for each optimizer.

        Returns:
            List: learning rates for each optimizer.
        """
        return [self.config.lr_gen, self.config.lr_disc]

    def get_scheduler(self, optimizer) -> List:
        """Set the schedulers for each optimizer.

        Args:
            optimizer (List[`torch.optim.Optimizer`]): List of optimizers.

        Returns:
            List: Schedulers, one for each optimizer.
        """
        scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
        scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
        return [scheduler1, scheduler2]

    def get_criterion(self):
        """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
        `train_step()`."""
        from TTS.tts.layers.losses import (  # pylint: disable=import-outside-toplevel
            VitsDiscriminatorLoss,
            VitsGeneratorLoss,
        )

        return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]

    @staticmethod
    def make_symbols(config):
        """Create a custom arrangement of symbols used by the model. The output list of symbols propagates along the
        whole training and inference steps."""
        _pad = config.characters["pad"]
        _punctuations = config.characters["punctuations"]
        _letters = config.characters["characters"]
        _letters_ipa = config.characters["phonemes"]
        symbols = [_pad] + list(_punctuations) + list(_letters)
        if config.use_phonemes:
            symbols += list(_letters_ipa)
        return symbols

    @staticmethod
    def get_characters(config: Coqpit):
        if config.characters is not None:
            symbols = Vits.make_symbols(config)
        else:
            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
                parse_symbols,
                phonemes,
                symbols,
            )

            config.characters = parse_symbols()
            if config.use_phonemes:
                symbols = phonemes
        num_chars = len(symbols) + getattr(config, "add_blank", False)
        return symbols, config, num_chars

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load the model checkpoint and setup for training or inference."""
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
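A hedged loading sketch for inference-time use; the import paths and checkpoint path are assumptions for illustration:

    from TTS.tts.configs import VitsConfig  # assumed import path
    from TTS.tts.models.vits import Vits    # assumed import path

    config = VitsConfig()
    model = Vits(config)
    model.load_checkpoint(config, "output/vits/best_model.pth.tar", eval=True)  # hypothetical path
    # the assert above guarantees the model has left training mode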

@@ -2,6 +2,7 @@ import datetime
 import importlib
 import pickle
 
+import fsspec
 import numpy as np
 import tensorflow as tf
 
@@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
         "r": r,
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)
 
 
 def load_checkpoint(model, checkpoint_path):
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:
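The point of routing I/O through `fsspec.open` instead of the builtin `open` is that the path scheme selects the filesystem: plain paths stay local while URLs such as `s3://` or `gs://` resolve through the matching fsspec backend. A small sketch (the bucket name is hypothetical):

    import pickle
    import fsspec

    state = {"step": 100, "r": 2}
    # local path and remote URL go through the same code path
    for path in ("checkpoint.pkl", "s3://my-bucket/checkpoint.pkl"):  # hypothetical bucket
        with fsspec.open(path, "wb") as f:
            pickle.dump(state, f)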

@@ -1,6 +1,7 @@
 import datetime
 import pickle
 
+import fsspec
 import tensorflow as tf
 
 
@@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
         "r": r,
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)
 
 
 def load_checkpoint(model, checkpoint_path):
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:
 
 
@@ -1,3 +1,4 @@
+import fsspec
 import tensorflow as tf
 
 
@@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
     print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
     if output_path is not None:
         # save model binary if output_path is provided
-        with open(output_path, "wb") as f:
+        with fsspec.open(output_path, "wb") as f:
            f.write(tflite_model)
        return None
    return tflite_model

@@ -3,6 +3,7 @@ import os
 import random
 from typing import Any, Dict, List, Tuple, Union
 
+import fsspec
 import numpy as np
 import torch
 from coqpit import Coqpit
 
@@ -84,12 +85,12 @@ class SpeakerManager:
 
     @staticmethod
     def _load_json(json_file_path: str) -> Dict:
-        with open(json_file_path) as f:
+        with fsspec.open(json_file_path, "r") as f:
             return json.load(f)
 
     @staticmethod
     def _save_json(json_file_path: str, data: dict) -> None:
-        with open(json_file_path, "w") as f:
+        with fsspec.open(json_file_path, "w") as f:
             json.dump(data, f, indent=4)
 
     @property
 
@@ -294,9 +295,10 @@ def _set_file_path(path):
     Intended to band aid the different paths returned in restored and continued training."""
     path_restore = os.path.join(os.path.dirname(path), "speakers.json")
     path_continue = os.path.join(path, "speakers.json")
-    if os.path.exists(path_restore):
+    fs = fsspec.get_mapper(path).fs
+    if fs.exists(path_restore):
         return path_restore
-    if os.path.exists(path_continue):
+    if fs.exists(path_continue):
         return path_continue
     raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")
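`fsspec.get_mapper(path).fs` returns the filesystem object behind a path, so existence checks work the same for local directories and remote stores. The idiom in isolation (sketch; paths are hypothetical):

    import fsspec

    fs = fsspec.get_mapper("output/run-March-01").fs  # hypothetical run folder
    if fs.exists("output/run-March-01/speakers.json"):
        print("found speakers file for this run")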

@@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
         json_file = out_path
     else:
         json_file = _set_file_path(out_path)
-    with open(json_file) as f:
+    with fsspec.open(json_file, "r") as f:
         return json.load(f)
 
 
@@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
     """Saves speaker mapping if not yet present."""
     if out_path is not None:
         speakers_json_path = _set_file_path(out_path)
-        with open(speakers_json_path, "w") as f:
+        with fsspec.open(speakers_json_path, "w") as f:
             json.dump(speaker_mapping, f, indent=4)
 
 
@@ -358,10 +360,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
     elif c.use_d_vector_file and c.d_vector_file:
         # new speaker manager with external speaker embeddings.
         speaker_manager.set_d_vectors_from_file(c.d_vector_file)
-    elif c.use_d_vector_file and not c.d_vector_file:  # new speaker manager with speaker IDs file.
-        raise "use_d_vector_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
+    elif c.use_d_vector_file and not c.d_vector_file:
+        raise ValueError("use_d_vector_file is True, so you need to pass an external speaker embedding file.")
     elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
         # new speaker manager with speaker IDs file.
         speaker_manager.set_speaker_ids_from_file(c.speakers_file)
     print(
-        " > Training with {} speakers: {}".format(
+        " > Speaker manager is loaded with {} speakers: {}".format(
             speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
         )
     )

@@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
     import tensorflow as tf
 
 
-def text_to_seq(text, CONFIG):
+def text_to_seq(text, CONFIG, custom_symbols=None):
     text_cleaner = [CONFIG.text_cleaner]
     # text or phonemes to sequence vector
     if CONFIG.use_phonemes:
 
@@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG):
                 tp=CONFIG.characters,
                 add_blank=CONFIG.add_blank,
                 use_espeak_phonemes=CONFIG.use_espeak_phonemes,
+                custom_symbols=custom_symbols,
             ),
             dtype=np.int32,
         )
     else:
         seq = np.asarray(
             text_to_sequence(
-                text,
-                text_cleaner,
-                tp=CONFIG.characters,
-                add_blank=CONFIG.add_blank,
+                text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
             ),
             dtype=np.int32,
         )
 
@@ -229,13 +227,16 @@ def synthesis(
     """
     # GST processing
     style_mel = None
+    custom_symbols = None
     if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
             style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
+    if hasattr(model, "make_symbols"):
+        custom_symbols = model.make_symbols(CONFIG)
     # preprocess the given text
-    text_inputs = text_to_seq(text, CONFIG)
+    text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
 
@@ -274,15 +275,18 @@ def synthesis(
     # convert outputs to numpy
     # plot results
     wav = None
-    if use_griffin_lim:
-        wav = inv_spectrogram(model_outputs, ap, CONFIG)
-        # trim silence
-        if do_trim_silence:
-            wav = trim_silence(wav, ap)
+    if hasattr(model, "END2END") and model.END2END:
+        wav = model_outputs.squeeze(0)
+    else:
+        if use_griffin_lim:
+            wav = inv_spectrogram(model_outputs, ap, CONFIG)
+            # trim silence
+            if do_trim_silence:
+                wav = trim_silence(wav, ap)
     return_dict = {
         "wav": wav,
         "alignments": alignments,
         "model_outputs": model_outputs,
         "text_inputs": text_inputs,
+        "outputs": outputs,
     }
     return return_dict
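The `END2END` branch matters for models like VITS whose decoder already emits a waveform: Griffin-Lim inversion is skipped and the model output is returned directly. A hedged sketch of how a caller can consume the returned dict (arguments match the `test_run` call above; inputs assumed):

    outputs = synthesis(model, "Hello there.", CONFIG, False, ap, use_griffin_lim=True)
    wav = outputs["wav"]          # waveform: decoder output or Griffin-Lim inversion
    attn = outputs["alignments"]  # alignment map, e.g. for plot_alignment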

@@ -2,10 +2,9 @@
 # adapted from https://github.com/keithito/tacotron
 
 import re
-import unicodedata
+from typing import Dict, List
 
 import gruut
 from packaging import version
 
 from TTS.tts.utils.text import cleaners
 from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
 
@@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
 
 _symbols = symbols
 _phonemes = phonemes
+
 # Regular expression matching text enclosed in curly braces:
 _CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
 
 
@@ -81,7 +81,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
         # Fix a few phonemes
         ph = ph.translate(GRUUT_TRANS_TABLE)
 
-        print(" > Phonemes: {}".format(ph))
         return ph
 
     raise ValueError(f" [!] Language {language} is not supported for phonemization.")
 
@@ -106,13 +105,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
 
 
 def phoneme_to_sequence(
-    text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False
-):
+    text: str,
+    cleaner_names: List[str],
+    language: str,
+    enable_eos_bos: bool = False,
+    custom_symbols: List[str] = None,
+    tp: Dict = None,
+    add_blank: bool = False,
+    use_espeak_phonemes: bool = False,
+) -> List[int]:
+    """Converts a string of phonemes to a sequence of IDs.
+    If `custom_symbols` is provided, it will override the default symbols.
+
+    Args:
+        text (str): string to convert to a sequence
+        cleaner_names (List[str]): names of the cleaner functions to run the text through
+        language (str): text language key for phonemization.
+        enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
+        tp (Dict): dictionary of character parameters to use a custom character set.
+        add_blank (bool): option to add a blank token between each token.
+        use_espeak_phonemes (bool): use espeak-based lexicons to convert phonemes to sequences.
+
+    Returns:
+        List[int]: List of integers corresponding to the symbols in the text
+    """
     # pylint: disable=global-statement
     global _phonemes_to_id, _phonemes
-    if tp:
+
+    if custom_symbols is not None:
+        _phonemes = custom_symbols
+    elif tp:
         _, _phonemes = make_symbols(**tp)
-        _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
+
+    _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
 
     sequence = []
     clean_text = _clean_text(text, cleaner_names)
 
@@ -127,20 +151,22 @@ def phoneme_to_sequence(
         sequence = pad_with_eos_bos(sequence, tp=tp)
     if add_blank:
         sequence = intersperse(sequence, len(_phonemes))  # add a blank token (new), whose id number is len(_phonemes)
 
     return sequence
 
 
-def sequence_to_phoneme(sequence, tp=None, add_blank=False):
+def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
     # pylint: disable=global-statement
     """Converts a sequence of IDs back to a string"""
     global _id_to_phonemes, _phonemes
     if add_blank:
         sequence = list(filter(lambda x: x != len(_phonemes), sequence))
     result = ""
-    if tp:
+
+    if custom_symbols is not None:
+        _phonemes = custom_symbols
+    elif tp:
         _, _phonemes = make_symbols(**tp)
-        _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
+
+    _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
 
     for symbol_id in sequence:
         if symbol_id in _id_to_phonemes:
 
@@ -149,27 +175,32 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
     return result.replace("}{", " ")
 
 
-def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
+def text_to_sequence(
+    text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
+) -> List[int]:
     """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
 
     The text can optionally have ARPAbet sequences enclosed in curly braces embedded
     in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+    If `custom_symbols` is provided, it will override the default symbols.
 
     Args:
-        text: string to convert to a sequence
-        cleaner_names: names of the cleaner functions to run the text through
-        tp: dictionary of character parameters to use a custom character set.
+        text (str): string to convert to a sequence
+        cleaner_names (List[str]): names of the cleaner functions to run the text through
+        tp (Dict): dictionary of character parameters to use a custom character set.
+        add_blank (bool): option to add a blank token between each token.
 
    Returns:
-        List of integers corresponding to the symbols in the text
+        List[int]: List of integers corresponding to the symbols in the text
     """
     # pylint: disable=global-statement
     global _symbol_to_id, _symbols
-    if tp:
+
+    if custom_symbols is not None:
+        _symbols = custom_symbols
+    elif tp:
         _symbols, _ = make_symbols(**tp)
-        _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
+
+    _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
 
     sequence = []
 
     # Check for curly braces and treat their contents as ARPAbet:
     while text:
         m = _CURLY_RE.match(text)
 
@@ -185,16 +216,18 @@ def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
     return sequence
 
 
-def sequence_to_text(sequence, tp=None, add_blank=False):
+def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
     """Converts a sequence of IDs back to a string"""
     # pylint: disable=global-statement
     global _id_to_symbol, _symbols
     if add_blank:
         sequence = list(filter(lambda x: x != len(_symbols), sequence))
 
-    if tp:
+    if custom_symbols is not None:
+        _symbols = custom_symbols
+    elif tp:
         _symbols, _ = make_symbols(**tp)
-        _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
+
+    _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
 
     result = ""
     for symbol_id in sequence:
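With `custom_symbols`, a model can impose its own symbol table on both directions of the conversion; `tp` is only consulted when no custom list is given. A round-trip sketch under assumed inputs:

    custom_symbols = model.make_symbols(CONFIG)  # e.g. Vits.make_symbols, as defined earlier
    seq = text_to_sequence("hello", ["basic_cleaners"], custom_symbols=custom_symbols)
    text = sequence_to_text(seq, custom_symbols=custom_symbols)  # recovers "hello"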

@@ -28,10 +28,10 @@ def make_symbols(
         sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
     )  # this is to keep previous models compatible.
     # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-    _arpabet = ["@" + s for s in _phonemes_sorted]
+    # _arpabet = ["@" + s for s in _phonemes_sorted]
     # Export all symbols:
     _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
-    _symbols += _arpabet
+    # _symbols += _arpabet
     return _symbols, _phonemes

@@ -14,7 +14,10 @@ from TTS.tts.utils.data import StandardScaler
 
 
 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
-    """TODO: Merge this with audio.py"""
+    """Some of the audio processing functions using Torch for faster batch processing.
+
+    TODO: Merge this with audio.py
+    """
 
     def __init__(
         self,
 
@@ -28,6 +31,8 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         mel_fmax=None,
         n_mels=80,
         use_mel=False,
+        do_amp_to_db=False,
+        spec_gain=1.0,
     ):
         super().__init__()
         self.n_fft = n_fft
 
@@ -39,6 +44,8 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         self.mel_fmax = mel_fmax
         self.n_mels = n_mels
         self.use_mel = use_mel
+        self.do_amp_to_db = do_amp_to_db
+        self.spec_gain = spec_gain
         self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
         self.mel_basis = None
         if use_mel:
 
@@ -79,6 +86,8 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
         if self.use_mel:
             S = torch.matmul(self.mel_basis.to(x), S)
+        if self.do_amp_to_db:
+            S = self._amp_to_db(S, spec_gain=self.spec_gain)
         return S
 
     def _build_mel_basis(self):
 
@@ -87,6 +96,14 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         )
         self.mel_basis = torch.from_numpy(mel_basis).float()
 
+    @staticmethod
+    def _amp_to_db(x, spec_gain=1.0):
+        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
+
+    @staticmethod
+    def _db_to_amp(x, spec_gain=1.0):
+        return torch.exp(x) / spec_gain
+
 
 # pylint: disable=too-many-public-methods
 class AudioProcessor(object):
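The `_amp_to_db`/`_db_to_amp` helpers above are a clamped natural log and its exact inverse, so a round trip is lossless for magnitudes above the 1e-5 floor. A quick check (sketch):

    import torch

    x = torch.rand(4, 80, 100) + 1e-3         # fake magnitude spectrogram batch
    db = torch.log(torch.clamp(x, min=1e-5))  # same as TorchSTFT._amp_to_db with spec_gain=1.0
    x_back = torch.exp(db)                    # same as TorchSTFT._db_to_amp with spec_gain=1.0
    assert torch.allclose(x, x_back, atol=1e-6)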

@@ -97,33 +114,93 @@ class AudioProcessor(object):
     of the class with the model config. They are not meaningful for all the arguments.
 
     Args:
-        sample_rate (int, optional): target audio sampling rate. Defaults to None.
-        resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
-        num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
-        log_func (int, optional): log exponent used for converting spectrogram aplitude to DB.
-        min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None.
-        frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
-        frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None.
-        hop_length (int, optional): number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
-        win_length (int, optional): STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
-        ref_level_db (int, optional): reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
-        fft_size (int, optional): FFT window size for STFT. Defaults to 1024.
-        power (int, optional): Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
-        preemphasis (float, optional): Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
-        signal_norm (bool, optional): enable/disable signal normalization. Defaults to None.
-        symmetric_norm (bool, optional): enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
-        max_norm (float, optional): ```k``` defining the normalization range. Defaults to None.
-        mel_fmin (int, optional): minimum filter frequency for computing melspectrograms. Defaults to None.
-        mel_fmax (int, optional): maximum filter frequency for computing melspectrograms. Defaults to None.
-        spec_gain (int, optional): gain applied when converting amplitude to DB. Defaults to 20.
-        stft_pad_mode (str, optional): Padding mode for STFT. Defaults to 'reflect'.
-        clip_norm (bool, optional): enable/disable clipping the out-of-range values in the normalized audio signal. Defaults to True.
-        griffin_lim_iters (int, optional): Number of GriffinLim iterations. Defaults to None.
-        do_trim_silence (bool, optional): enable/disable silence trimming when loading the audio signal. Defaults to False.
-        trim_db (int, optional): DB threshold used for silence trimming. Defaults to 60.
-        do_sound_norm (bool, optional): enable/disable signal normalization. Defaults to False.
-        stats_path (str, optional): Path to the computed stats file. Defaults to None.
-        verbose (bool, optional): enable/disable logging. Defaults to True.
+        sample_rate (int, optional):
+            target audio sampling rate. Defaults to None.
+
+        resample (bool, optional):
+            enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
+
+        num_mels (int, optional):
+            number of melspectrogram dimensions. Defaults to None.
+
+        log_func (int, optional):
+            log exponent used for converting spectrogram amplitude to DB.
+
+        min_level_db (int, optional):
+            minimum db threshold for the computed melspectrograms. Defaults to None.
+
+        frame_shift_ms (int, optional):
+            milliseconds of frames between STFT columns. Defaults to None.
+
+        frame_length_ms (int, optional):
+            milliseconds of STFT window length. Defaults to None.
+
+        hop_length (int, optional):
+            number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
+
+        win_length (int, optional):
+            STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
+
+        ref_level_db (int, optional):
+            reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
+
+        fft_size (int, optional):
+            FFT window size for STFT. Defaults to 1024.
+
+        power (int, optional):
+            Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
+
+        preemphasis (float, optional):
+            Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
+
+        signal_norm (bool, optional):
+            enable/disable signal normalization. Defaults to None.
+
+        symmetric_norm (bool, optional):
+            enable/disable symmetric normalization. If set True, normalization is performed in the range [-k, k], else [0, k]. Defaults to None.
+
+        max_norm (float, optional):
+            ```k``` defining the normalization range. Defaults to None.
+
+        mel_fmin (int, optional):
+            minimum filter frequency for computing melspectrograms. Defaults to None.
+
+        mel_fmax (int, optional):
+            maximum filter frequency for computing melspectrograms. Defaults to None.
+
+        spec_gain (int, optional):
+            gain applied when converting amplitude to DB. Defaults to 20.
+
+        stft_pad_mode (str, optional):
+            Padding mode for STFT. Defaults to 'reflect'.
+
+        clip_norm (bool, optional):
+            enable/disable clipping the out-of-range values in the normalized audio signal. Defaults to True.
+
+        griffin_lim_iters (int, optional):
+            Number of GriffinLim iterations. Defaults to None.
+
+        do_trim_silence (bool, optional):
+            enable/disable silence trimming when loading the audio signal. Defaults to False.
+
+        trim_db (int, optional):
+            DB threshold used for silence trimming. Defaults to 60.
+
+        do_sound_norm (bool, optional):
+            enable/disable signal normalization. Defaults to False.
+
+        do_amp_to_db_linear (bool, optional):
+            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
+
+        do_amp_to_db_mel (bool, optional):
+            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
+
+        stats_path (str, optional):
+            Path to the computed stats file. Defaults to None.
+
+        verbose (bool, optional):
+            enable/disable logging. Defaults to True.
+
     """
 
     def __init__(
 
@@ -153,6 +230,8 @@ class AudioProcessor(object):
         do_trim_silence=False,
         trim_db=60,
         do_sound_norm=False,
+        do_amp_to_db_linear=True,
+        do_amp_to_db_mel=True,
         stats_path=None,
         verbose=True,
         **_,
 
@@ -181,6 +260,8 @@ class AudioProcessor(object):
         self.do_trim_silence = do_trim_silence
         self.trim_db = trim_db
         self.do_sound_norm = do_sound_norm
+        self.do_amp_to_db_linear = do_amp_to_db_linear
+        self.do_amp_to_db_mel = do_amp_to_db_mel
         self.stats_path = stats_path
         # setup exp_func for db to amp conversion
         if log_func == "np.log":
 
@@ -381,7 +462,6 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Decibels spectrogram.
         """
-
         return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
 
     # pylint: disable=no-self-use
 
@@ -448,7 +528,10 @@ class AudioProcessor(object):
             D = self._stft(self.apply_preemphasis(y))
         else:
             D = self._stft(y)
-        S = self._amp_to_db(np.abs(D))
+        if self.do_amp_to_db_linear:
+            S = self._amp_to_db(np.abs(D))
+        else:
+            S = np.abs(D)
         return self.normalize(S).astype(np.float32)
 
     def melspectrogram(self, y: np.ndarray) -> np.ndarray:
 
@@ -457,7 +540,10 @@ class AudioProcessor(object):
             D = self._stft(self.apply_preemphasis(y))
         else:
             D = self._stft(y)
-        S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
+        if self.do_amp_to_db_mel:
+            S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
+        else:
+            S = self._linear_to_mel(np.abs(D))
         return self.normalize(S).astype(np.float32)
 
     def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
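The two new flags let a model consume linear-amplitude features while the rest of the pipeline stays unchanged; with `do_amp_to_db_mel=False` the mel spectrogram skips log compression before normalization. A hedged configuration sketch (other audio parameters your setup requires are omitted here):

    from TTS.utils.audio import AudioProcessor

    ap = AudioProcessor(
        sample_rate=22050,
        num_mels=80,
        do_amp_to_db_linear=True,  # keep dB-scale linear spectrograms
        do_amp_to_db_mel=False,    # emit linear-amplitude mels, e.g. for models that compress internally
    )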

@@ -1,15 +1,14 @@
 # -*- coding: utf-8 -*-
 import datetime
-import glob
 import importlib
 import os
 import re
-import shutil
 import subprocess
 import sys
 from pathlib import Path
 from typing import Dict
 
+import fsspec
 import torch
 
 
@@ -58,23 +57,22 @@ def get_commit_hash():
     return commit
 
 
-def create_experiment_folder(root_path, model_name):
-    """Create a folder with the current date and time"""
+def get_experiment_folder_path(root_path, model_name):
+    """Get an experiment folder path with the current date and time"""
     date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
     commit_hash = get_commit_hash()
     output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
-    os.makedirs(output_folder, exist_ok=True)
-    print(" > Experiment folder: {}".format(output_folder))
     return output_folder
 
 
 def remove_experiment_folder(experiment_path):
     """Check the folder for a checkpoint; otherwise remove the folder"""
 
-    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
+    fs = fsspec.get_mapper(experiment_path).fs
+    checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
     if not checkpoint_files:
-        if os.path.exists(experiment_path):
-            shutil.rmtree(experiment_path, ignore_errors=True)
+        if fs.exists(experiment_path):
+            fs.rm(experiment_path, recursive=True)
             print(" ! Run is removed from {}".format(experiment_path))
     else:
         print(" ! Run is kept in {}".format(experiment_path))

@@ -1,9 +1,11 @@
 import datetime
-import glob
+import json
 import os
 import pickle as pickle_tts
-from shutil import copyfile
+import shutil
+from typing import Any
 
+import fsspec
 import torch
 from coqpit import Coqpit
 
@@ -24,7 +26,7 @@ class AttrDict(dict):
         self.__dict__ = self
 
 
-def copy_model_files(config, out_path, new_fields):
+def copy_model_files(config: Coqpit, out_path, new_fields):
     """Copy config.json and other model files to training folder and add
     new fields.
 
@@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
     copy_config_path = os.path.join(out_path, "config.json")
     # add extra information fields
     config.update(new_fields, allow_new=True)
-    config.save_json(copy_config_path)
+    # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
+    with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
+        json.dump(config.to_dict(), f, indent=4)
+
     # copy model stats file if available
     if config.audio.stats_path is not None:
         copy_stats_path = os.path.join(out_path, "scale_stats.npy")
-        if not os.path.exists(copy_stats_path):
-            copyfile(
-                config.audio.stats_path,
-                copy_stats_path,
-            )
+        filesystem = fsspec.get_mapper(copy_stats_path).fs
+        if not filesystem.exists(copy_stats_path):
+            with fsspec.open(config.audio.stats_path, "rb") as source_file:
+                with fsspec.open(copy_stats_path, "wb") as target_file:
+                    shutil.copyfileobj(source_file, target_file)
+
+
+def load_fsspec(path: str, **kwargs) -> Any:
+    """Like torch.load but can load from other locations (e.g. s3:// , gs://).
+
+    Args:
+        path: Any path or url supported by fsspec.
+        **kwargs: Keyword arguments forwarded to torch.load.
+
+    Returns:
+        Object stored in path.
+    """
+    with fsspec.open(path, "rb") as f:
+        return torch.load(f, **kwargs)
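`load_fsspec`, together with `save_fsspec` defined further below, simply wraps `torch.load`/`torch.save` around an `fsspec` file handle, which is what lets every checkpoint path in the codebase be local or remote. A usage sketch (the remote URL is hypothetical):

    state = {"model": model.state_dict(), "step": 1000}
    save_fsspec(state, "gs://my-bucket/run1/checkpoint_1000.pth.tar")  # hypothetical URL
    state = load_fsspec("gs://my-bucket/run1/checkpoint_1000.pth.tar", map_location="cpu")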
 
 
 def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
     try:
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
     except ModuleNotFoundError:
         pickle_tts.Unpickler = RenamingUnpickler
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
     model.load_state_dict(state["model"])
     if use_cuda:
         model.cuda()
 
@@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pyli
     return model, state
 
 
+def save_fsspec(state: Any, path: str, **kwargs):
+    """Like torch.save but can save to other locations (e.g. s3:// , gs://).
+
+    Args:
+        state: State object to save
+        path: Any path or url supported by fsspec.
+        **kwargs: Keyword arguments forwarded to torch.save.
+    """
+    with fsspec.open(path, "wb") as f:
+        torch.save(state, f, **kwargs)
+
+
 def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
     if hasattr(model, "module"):
         model_state = model.module.state_dict()
 
@@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
         "date": datetime.date.today().strftime("%B %d, %Y"),
     }
     state.update(kwargs)
-    torch.save(state, output_path)
+    save_fsspec(state, output_path)
 
 
 def save_checkpoint(
 
@@ -147,18 +178,16 @@ def save_best_model(
             model_loss=current_loss,
             **kwargs,
         )
+        fs = fsspec.get_mapper(out_path).fs
         # only delete previous if current is saved successfully
         if not keep_all_best or (current_step < keep_after):
-            model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
+            model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
             for model_name in model_names:
-                if os.path.basename(model_name) == best_model_name:
-                    continue
-                os.remove(model_name)
-        # create symlink to best model for convinience
-        link_name = "best_model.pth.tar"
-        link_path = os.path.join(out_path, link_name)
-        if os.path.islink(link_path) or os.path.isfile(link_path):
-            os.remove(link_path)
-        os.symlink(best_model_name, os.path.join(out_path, link_name))
+                if os.path.basename(model_name) != best_model_name:
+                    fs.rm(model_name)
+        # create a shortcut which always points to the currently best model
+        shortcut_name = "best_model.pth.tar"
+        shortcut_path = os.path.join(out_path, shortcut_name)
+        fs.copy(checkpoint_path, shortcut_path)
         best_loss = current_loss
     return best_loss

@@ -1,2 +1,24 @@
 from TTS.utils.logging.console_logger import ConsoleLogger
 from TTS.utils.logging.tensorboard_logger import TensorboardLogger
+from TTS.utils.logging.wandb_logger import WandbLogger
+
+
+def init_logger(config):
+    if config.dashboard_logger == "tensorboard":
+        dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)
+
+    elif config.dashboard_logger == "wandb":
+        project_name = config.model
+        if config.project_name:
+            project_name = config.project_name
+
+        dashboard_logger = WandbLogger(
+            project=project_name,
+            name=config.run_name,
+            config=config,
+            entity=config.wandb_entity,
+        )
+
+    dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+
+    return dashboard_logger
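`init_logger` dispatches on `config.dashboard_logger`, so switching dashboards is a config change rather than a code change. A sketch of the relevant config fields (field names as used above; values hypothetical):

    config.dashboard_logger = "wandb"       # or "tensorboard"
    config.project_name = "my-tts-project"  # hypothetical project
    config.run_name = "vits-ljspeech-run1"  # hypothetical run name
    config.wandb_entity = None
    dashboard_logger = init_logger(config)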

@@ -38,7 +38,7 @@ class ConsoleLogger:
     def print_train_start(self):
         print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
 
-    def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict):
+    def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_dict):
         indent = " | > "
         print()
         log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
 
@@ -50,13 +50,6 @@ class ConsoleLogger:
                 log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
             else:
                 log_text += "{}{}: {:.5f} \n".format(indent, key, value)
-        for idx, (key, value) in enumerate(log_dict.items()):
-            if isinstance(value, list):
-                log_text += f"{indent}{key}: {value[0]:.{value[1]}f}"
-            else:
-                log_text += f"{indent}{key}: {value}"
-            if idx < len(log_dict) - 1:
-                log_text += "\n"
         print(log_text, flush=True)
 
     # pylint: disable=unused-argument
 
 
@@ -7,10 +7,8 @@ class TensorboardLogger(object):
     def __init__(self, log_dir, model_name):
         self.model_name = model_name
         self.writer = SummaryWriter(log_dir)
-        self.train_stats = {}
-        self.eval_stats = {}
 
-    def tb_model_weights(self, model, step):
+    def model_weights(self, model, step):
         layer_num = 1
         for name, param in model.named_parameters():
             if param.numel() == 1:
 
@@ -41,32 +39,41 @@ class TensorboardLogger(object):
         except RuntimeError:
             traceback.print_exc()
 
-    def tb_train_step_stats(self, step, stats):
+    def train_step_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)
 
-    def tb_train_epoch_stats(self, step, stats):
+    def train_epoch_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step)
 
-    def tb_train_figures(self, step, figures):
+    def train_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step)
 
-    def tb_train_audios(self, step, audios, sample_rate):
+    def train_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate)
 
-    def tb_eval_stats(self, step, stats):
+    def eval_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step)
 
-    def tb_eval_figures(self, step, figures):
+    def eval_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step)
 
-    def tb_eval_audios(self, step, audios, sample_rate):
+    def eval_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate)
 
-    def tb_test_audios(self, step, audios, sample_rate):
+    def test_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate)
 
-    def tb_test_figures(self, step, figures):
+    def test_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step)
 
-    def tb_add_text(self, title, text, step):
+    def add_text(self, title, text, step):
         self.writer.add_text(title, text, step)
 
+    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):  # pylint: disable=W0613, R0201
+        yield
+
     def flush(self):
         self.writer.flush()
 
     def finish(self):
         self.writer.close()

@@ -0,0 +1,111 @@
+# pylint: disable=W0613
+
+import traceback
+from pathlib import Path
+
+try:
+    import wandb
+    from wandb import finish, init  # pylint: disable=W0611
+except ImportError:
+    wandb = None
+
+
+class WandbLogger:
+    def __init__(self, **kwargs):
+
+        if not wandb:
+            raise Exception("install wandb using `pip install wandb` to use WandbLogger")
+
+        self.run = None
+        self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
+        self.model_name = self.run.config.model
+        self.log_dict = {}
+
+    def model_weights(self, model):
+        layer_num = 1
+        for name, param in model.named_parameters():
+            if param.numel() == 1:
+                self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()})
+            else:
+                self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()})
+                self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
+                self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
+                self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
+                self.log_dict["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param)
+                self.log_dict["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad)
+            layer_num += 1
+
+    def dict_to_scalar(self, scope_name, stats):
+        for key, value in stats.items():
+            self.log_dict["{}/{}".format(scope_name, key)] = value
+
+    def dict_to_figure(self, scope_name, figures):
+        for key, value in figures.items():
+            self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value)
+
+    def dict_to_audios(self, scope_name, audios, sample_rate):
+        for key, value in audios.items():
+            if value.dtype == "float16":
+                value = value.astype("float32")
+            try:
+                self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
+            except RuntimeError:
+                traceback.print_exc()
+
+    def log(self, log_dict, prefix="", flush=False):
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = value
+        if flush:  # for cases where you don't want to accumulate data
+            self.flush()
+
+    def train_step_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats)
+
+    def train_epoch_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats)
+
+    def train_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_TrainFigures", figures)
+
+    def train_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate)
+
+    def eval_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_EvalStats", stats)
+
+    def eval_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_EvalFigures", figures)
+
+    def eval_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate)
+
+    def test_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate)
+
+    def test_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_TestFigures", figures)
+
+    def add_text(self, title, text, step):
+        pass
+
+    def flush(self):
+        if self.run:
+            wandb.log(self.log_dict)
+        self.log_dict = {}
+
+    def finish(self):
+        if self.run:
+            self.run.finish()
+
+    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
+        if not self.run:
+            return
+        name = "_".join([self.run.id, name])
+        artifact = wandb.Artifact(name, type=artifact_type)
+        data_path = Path(file_or_dir)
+        if data_path.is_dir():
+            artifact.add_dir(str(data_path))
+        elif data_path.is_file():
+            artifact.add_file(str(data_path))
+
+        self.run.log_artifact(artifact, aliases=aliases)
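`WandbLogger` buffers everything in `log_dict` and only ships it on `flush()`, so one wandb step can aggregate scalars, figures, and audio from a training iteration. A hedged usage sketch (project/run names hypothetical; note `__init__` reads `run.config.model`, so `config` must carry a `model` key):

    logger = WandbLogger(project="tts-experiments", name="run1", config={"model": "vits"})
    logger.train_step_stats(step=100, stats={"loss": 0.42})
    logger.flush()   # one wandb.log() call with the accumulated dict
    logger.finish()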

@@ -64,6 +64,7 @@ class ModelManager(object):
     def list_models(self):
         print(" Name format: type/language/dataset/model")
         models_name_list = []
+        model_count = 1
         for model_type in self.models_dict:
             for lang in self.models_dict[model_type]:
                 for dataset in self.models_dict[model_type][lang]:
 
@@ -71,10 +72,11 @@ class ModelManager(object):
                     model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
                     output_path = os.path.join(self.output_prefix, model_full_name)
                     if os.path.exists(output_path):
-                        print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
+                        print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
                     else:
-                        print(f" >: {model_type}/{lang}/{dataset}/{model}")
+                        print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
                     models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
+                    model_count += 1
         return models_name_list
 
     def download_model(self, model_name):

@@ -12,7 +12,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
 from TTS.tts.utils.synthesis import synthesis, trim_silence
-from TTS.tts.utils.text import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
 
@@ -103,6 +102,34 @@ class Synthesizer(object):
             self.num_speakers = self.speaker_manager.num_speakers
             self.d_vector_dim = self.speaker_manager.d_vector_dim
 
+    def _set_tts_speaker_file(self):
+        """Set the TTS speaker file used by a multi-speaker model."""
+        # setup if multi-speaker settings are in the global model config
+        if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
+            if self.tts_config.use_d_vector_file:
+                self.tts_speakers_file = (
+                    self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
+                )
+                self.tts_config["d_vector_file"] = self.tts_speakers_file
+            else:
+                self.tts_speakers_file = (
+                    self.tts_speakers_file if self.tts_speakers_file else self.tts_config["speakers_file"]
+                )
+
+        # setup if multi-speaker settings are in the model args config
+        if (
+            self.tts_speakers_file is None
+            and hasattr(self.tts_config, "model_args")
+            and hasattr(self.tts_config.model_args, "use_speaker_embedding")
+            and self.tts_config.model_args.use_speaker_embedding
+        ):
+            _args = self.tts_config.model_args
+            if _args.use_d_vector_file:
+                self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["d_vector_file"]
+                _args["d_vector_file"] = self.tts_speakers_file
+            else:
+                self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["speakers_file"]
+
     def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
         """Load the TTS model.
 
 
@@ -113,29 +140,15 @@ class Synthesizer(object):
         """
         # pylint: disable=global-statement
 
-        global symbols, phonemes
         self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
 
-        if self.tts_config.has("characters") and self.tts_config.characters:
-            symbols, phonemes = make_symbols(**self.tts_config.characters)
-
-        if self.use_phonemes:
-            self.input_size = len(phonemes)
-        else:
-            self.input_size = len(symbols)
-
-        if self.tts_config.use_speaker_embedding is True:
-            self.tts_speakers_file = (
-                self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
-            )
-            self.tts_config["d_vector_file"] = self.tts_speakers_file
-
         self.tts_model = setup_tts_model(config=self.tts_config)
         self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()
+        self._set_tts_speaker_file()
 
     def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         """Load the vocoder model.
 
@@ -187,15 +200,22 @@ class Synthesizer(object):
         """
         start_time = time.time()
         wavs = []
-        speaker_embedding = None
         sens = self.split_into_sentences(text)
         print(" > Text split into sentences.")
         print(sens)
 
+        # handle multi-speaker
+        speaker_embedding = None
+        speaker_id = None
         if self.tts_speakers_file:
-            # get the speaker embedding from the saved d_vectors.
             if speaker_idx and isinstance(speaker_idx, str):
-                speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
+                if self.tts_config.use_d_vector_file:
+                    # get the speaker embedding from the saved d_vectors.
+                    speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
+                else:
+                    # get speaker idx from the speaker name
+                    speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
+
         elif not speaker_idx and not speaker_wav:
             raise ValueError(
                 " [!] Looks like you are using a multi-speaker model. "
 
@@ -224,14 +244,14 @@ class Synthesizer(object):
                 CONFIG=self.tts_config,
                 use_cuda=self.use_cuda,
                 ap=self.ap,
-                speaker_id=None,
+                speaker_id=speaker_id,
                 style_wav=style_wav,
                 enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
                 use_griffin_lim=use_gl,
                 d_vector=speaker_embedding,
            )
             waveform = outputs["wav"]
-            mel_postnet_spec = outputs["model_outputs"]
+            mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy()
             if not use_gl:
                 # denormalize tts output based on tts audio config
                 mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
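With this change the multi-speaker branch supports both conditioning styles: d-vector files map a speaker name to an embedding, while plain speaker-ID models map the name to an integer ID that `synthesis` embeds internally. A hedged call sketch against a multi-speaker model (the speaker name must exist in the model's speakers file; this one is hypothetical):

    wav = synthesizer.tts("A test sentence.", speaker_idx="VCTK_p225")  # hypothetical speaker name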

@@ -1,5 +1,5 @@
 import importlib
-from typing import Dict
+from typing import Dict, List
 
 import torch
 
 
@@ -48,7 +48,7 @@ def get_scheduler(
 
 
 def get_optimizer(
-    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
+    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
 ) -> torch.optim.Optimizer:
     """Find, initialize and return an optimizer.
 
 
@@ -66,4 +66,6 @@ def get_optimizer(
         optimizer = getattr(module, "RAdam")
     else:
         optimizer = getattr(torch.optim, optimizer_name)
-    return optimizer(model.parameters(), lr=lr, **optimizer_params)
+    if model is not None:
+        parameters = model.parameters()
+    return optimizer(parameters, lr=lr, **optimizer_params)
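Allowing `parameters` to be passed directly is what enables the VITS `get_optimizer` earlier in this change to hand in a filtered parameter list instead of a whole module. Both call styles, sketched:

    import torch

    net = torch.nn.Linear(4, 4)
    opt_a = get_optimizer("Adam", {"betas": (0.8, 0.99)}, lr=2e-4, model=net)
    opt_b = get_optimizer("Adam", {"betas": (0.8, 0.99)}, lr=2e-4, parameters=net.parameters())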

@@ -24,8 +24,10 @@ def setup_model(config: Coqpit):
     elif config.model.lower() == "wavegrad":
         MyModel = getattr(MyModel, "Wavegrad")
     else:
-        MyModel = getattr(MyModel, to_camel(config.model))
-        raise ValueError(f"Model {config.model} not exist!")
+        try:
+            MyModel = getattr(MyModel, to_camel(config.model))
+        except ModuleNotFoundError as e:
+            raise ValueError(f"Model {config.model} does not exist!") from e
     model = MyModel(config)
     return model
 
 
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
 
@@ -222,7 +223,7 @@ class GAN(BaseVocoder):
             checkpoint_path (str): Checkpoint file path.
             eval (bool, optional): If true, load the model for inference. Defaults to False.
         """
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # band-aid for older than v0.0.15 GAN models
         if "model_disc" in state:
             self.model_g.load_checkpoint(config, checkpoint_path, eval)
|
||||
|
|
|
@@ -33,10 +33,10 @@ class DiscriminatorP(torch.nn.Module):
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
         self.convs = nn.ModuleList(
             [
-                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                 norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
             ]
         )

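The padding fix above matters because "same" padding depends on the actual kernel size; the old code hardcoded the value for `kernel_size=5`. With HiFiGAN's standard helper:

```python
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    # "same" padding for a dilated 1-D convolution
    return (kernel_size * dilation - dilation) // 2


assert get_padding(5) == 2  # the previously hardcoded value
assert get_padding(3) == 1  # what a kernel_size=3 discriminator actually needs
```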
@@ -81,15 +81,15 @@ class MultiPeriodDiscriminator(torch.nn.Module):
     Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
     """

-    def __init__(self):
+    def __init__(self, use_spectral_norm=False):
         super().__init__()
         self.discriminators = nn.ModuleList(
             [
-                DiscriminatorP(2),
-                DiscriminatorP(3),
-                DiscriminatorP(5),
-                DiscriminatorP(7),
-                DiscriminatorP(11),
+                DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
             ]
         )

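The new `use_spectral_norm` flag is threaded into every `DiscriminatorP` so the same multi-period discriminator can serve models that prefer spectrally normalized convolutions (VITS does this for its discriminator). The toggle itself is one line:

```python
import torch.nn as nn


def pick_norm(use_spectral_norm: bool = False):
    # weight_norm is the HiFiGAN default; spectral_norm bounds the layer's
    # Lipschitz constant, which some setups use to stabilize GAN training
    return nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm


norm_f = pick_norm(True)
conv = norm_f(nn.Conv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)))
```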
@@ -99,7 +99,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
             x (Tensor): input waveform.

         Returns:
             [List[Tensor]]: list of scores from each discriminator.
+            [List[List[Tensor]]]: list of lists of features from each convolutional layer of each discriminator.

         Shapes:

@@ -5,6 +5,8 @@ import torch.nn.functional as F
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn.utils import remove_weight_norm, weight_norm

+from TTS.utils.io import load_fsspec
+
 LRELU_SLOPE = 0.1


@@ -168,6 +170,10 @@ class HifiganGenerator(torch.nn.Module):
         upsample_initial_channel,
         upsample_factors,
         inference_padding=5,
+        cond_channels=0,
+        conv_pre_weight_norm=True,
+        conv_post_weight_norm=True,
+        conv_post_bias=True,
     ):
         r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)

@@ -216,12 +222,21 @@ class HifiganGenerator(torch.nn.Module):
         for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
             self.resblocks.append(resblock(ch, k, d))
         # post convolution layer
-        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
+        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
+        if cond_channels > 0:
+            self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

-    def forward(self, x):
+        if not conv_pre_weight_norm:
+            remove_weight_norm(self.conv_pre)
+
+        if not conv_post_weight_norm:
+            remove_weight_norm(self.conv_post)
+
+    def forward(self, x, g=None):
         """
         Args:
-            x (Tensor): conditioning input tensor.
+            x (Tensor): feature input tensor.
+            g (Tensor): global conditioning input tensor.

         Returns:
             Tensor: output waveform.

@@ -231,6 +246,8 @@ class HifiganGenerator(torch.nn.Module):
             Tensor: [B, 1, T]
         """
         o = self.conv_pre(x)
+        if hasattr(self, "cond_layer"):
+            o = o + self.cond_layer(g)
         for i in range(self.num_upsamples):
             o = F.leaky_relu(o, LRELU_SLOPE)
             o = self.ups[i](o)

@@ -275,7 +292,7 @@ class HifiganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

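The generator changes above (`cond_channels`, the `g` argument, and the optional weight-norm removal on `conv_pre`/`conv_post`) exist so VITS can reuse HiFiGAN as its waveform decoder with a global speaker embedding added right after `conv_pre`. A shape-level sketch of the call pattern; the tensor sizes are illustrative assumptions:

```python
import torch

batch, mel_channels, frames = 2, 80, 50
cond_channels = 256

x = torch.randn(batch, mel_channels, frames)  # acoustic features: [B, C, T]
g = torch.randn(batch, cond_channels, 1)      # global conditioning: [B, C_cond, 1]

# gen = HifiganGenerator(..., cond_channels=cond_channels)
# inside forward: o = gen.conv_pre(x); o = o + gen.cond_layer(g)  # broadcast add over T
# wav = gen(x, g=g)  # -> [B, 1, T * prod(upsample_factors)]
```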
@@ -2,6 +2,7 @@ import torch
 from torch import nn
 from torch.nn.utils import weight_norm

+from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.melgan import ResidualStack


@@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -3,6 +3,7 @@ import math
 import numpy as np
 import torch

+from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
 from TTS.vocoder.layers.upsample import ConvUpsample


@@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -1,3 +1,5 @@
+from typing import List
+
 import numpy as np
 import torch
 import torch.nn.functional as F

@@ -10,18 +12,35 @@ LRELU_SLOPE = 0.1
 class UnivnetGenerator(torch.nn.Module):
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        cond_channels,
-        upsample_factors,
-        lvc_layers_each_block,
-        lvc_kernel_size,
-        kpnet_hidden_channels,
-        kpnet_conv_size,
-        dropout,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int,
+        cond_channels: int,
+        upsample_factors: List[int],
+        lvc_layers_each_block: int,
+        lvc_kernel_size: int,
+        kpnet_hidden_channels: int,
+        kpnet_conv_size: int,
+        dropout: float,
         use_weight_norm=True,
     ):
+        """Univnet Generator network.
+
+        Paper: https://arxiv.org/pdf/2106.07889.pdf
+
+        Args:
+            in_channels (int): Number of input tensor channels.
+            out_channels (int): Number of channels of the output tensor.
+            hidden_channels (int): Number of hidden network channels.
+            cond_channels (int): Number of channels of the conditioning tensors.
+            upsample_factors (List[int]): List of upsample factors for the upsampling layers.
+            lvc_layers_each_block (int): Number of LVC layers in each block.
+            lvc_kernel_size (int): Kernel size of the LVC layers.
+            kpnet_hidden_channels (int): Number of hidden channels in the kernel predictor network.
+            kpnet_conv_size (int): Number of convolution channels in the kernel predictor network.
+            dropout (float): Dropout rate.
+            use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
+        """
+
         super().__init__()
         self.in_channels = in_channels

@@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler

 from TTS.model import BaseModel
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.datasets import WaveGradDataset
 from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock

@@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler

 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss
 from TTS.vocoder.models.base_vocoder import BaseVocoder

@@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -1,6 +1,7 @@
 import datetime
 import pickle

+import fsspec
 import tensorflow as tf


@@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
         "date": datetime.date.today().strftime("%B %d, %Y"),
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)


 def load_checkpoint(model, checkpoint_path):
     """Load TF Vocoder model"""
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:

@@ -1,3 +1,4 @@
+import fsspec
 import tensorflow as tf


@@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=True):
     print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
     if output_path is not None:
         # save model binary if output_path is provided
-        with open(output_path, "wb") as f:
+        with fsspec.open(output_path, "wb") as f:
             f.write(tflite_model)
         return None
     return tflite_model

@@ -1,8 +1,11 @@
+from typing import Dict
+
 import numpy as np
 import torch
 from matplotlib import pyplot as plt

 from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio import AudioProcessor


 def interpolate_vocoder_input(scale_factor, spec):

@@ -26,12 +29,24 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec


-def plot_results(y_hat, y, ap, name_prefix):
-    """Plot vocoder model results"""
+def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
+    """Plot the predicted and the real waveform and their spectrograms.
+
+    Args:
+        y_hat (torch.tensor): Predicted waveform.
+        y (torch.tensor): Real waveform.
+        ap (AudioProcessor): Audio processor used to process the waveform.
+        name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
+
+    Returns:
+        Dict: output figures keyed by the name of the figures.
+    """
+    if name_prefix is None:
+        name_prefix = ""

     # select an instance from batch
-    y_hat = y_hat[0].squeeze(0).detach().cpu().numpy()
-    y = y[0].squeeze(0).detach().cpu().numpy()
+    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
+    y = y[0].squeeze().detach().cpu().numpy()

     spec_fake = ap.melspectrogram(y_hat).T
     spec_real = ap.melspectrogram(y).T

@@ -2,4 +2,5 @@ furo
 myst-parser == 0.15.1
 sphinx == 4.0.2
 sphinx_inline_tabs
-sphinx_copybutton
+sphinx_copybutton
+linkify-it-py

@@ -68,6 +68,8 @@ extensions = [
     "sphinx_inline_tabs",
 ]

+myst_enable_extensions = ['linkify',]
+
 # 'sphinxcontrib.katex',
 # 'sphinx.ext.autosectionlabel',

@@ -44,6 +44,7 @@
     :caption: `tts` Models

     models/glow_tts.md
+    models/vits.md

 .. toctree::
     :maxdepth: 2

@@ -0,0 +1,22 @@
# Glow TTS

Glow TTS is a normalizing flow model for text-to-speech. It is built on the generic Glow model previously
used in computer vision and vocoder models. It uses "monotonic alignment search" (MAS) to find the text-to-speech alignment
and uses the output to train a separate duration predictor network for faster inference run-time.

## Important resources & papers
- GlowTTS: https://arxiv.org/abs/2005.11129
- Glow (Generative Flow with invertible 1x1 Convolutions): https://arxiv.org/abs/1807.03039
- Normalizing Flows: https://blog.evjang.com/2018/01/nf1.html

## GlowTTS Config
```{eval-rst}
.. autoclass:: TTS.tts.configs.glow_tts_config.GlowTTSConfig
    :members:
```

## GlowTTS Model
```{eval-rst}
.. autoclass:: TTS.tts.models.glow_tts.GlowTTS
    :members:
```

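A minimal sketch of putting the documented classes together. The import paths come from the autoclass directives above; building the model directly from a config, rather than through `init_training` as the recipes in this commit do, is an assumption:

```python
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS

# configure and instantiate the model (training would go through the Trainer)
config = GlowTTSConfig(batch_size=32, eval_batch_size=16)
model = GlowTTS(config)
```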
@@ -0,0 +1,33 @@
# VITS

VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech) is an
end-to-end (encoder -> vocoder together) TTS model that takes advantage of SOTA DL techniques like GANs, VAEs, and
Normalizing Flows. It does not require external alignment annotations and learns the text-to-audio alignment
using MAS, as explained in the paper. The model architecture is a combination of the GlowTTS encoder and the HiFiGAN vocoder.
It is a feed-forward model with a 67.12x real-time factor on a GPU.

## Important resources & papers
- VITS: https://arxiv.org/pdf/2106.06103.pdf
- Neural Spline Flows: https://arxiv.org/abs/1906.04032
- Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf
- Generative Adversarial Networks: https://arxiv.org/abs/1406.2661
- HiFiGAN: https://arxiv.org/abs/2010.05646
- Normalizing Flows: https://blog.evjang.com/2018/01/nf1.html

## VitsConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.vits_config.VitsConfig
    :members:
```

## VitsArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.vits.VitsArgs
    :members:
```

## Vits Model
```{eval-rst}
.. autoclass:: TTS.tts.models.vits.Vits
    :members:
```

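A companion sketch to the GlowTTS example above, using the classes this page documents. Direct construction from the config is again an assumption; the recipe added in this commit goes through `init_training` instead:

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

# constructor signature assumed from the setup_model pattern in this diff
config = VitsConfig(batch_size=48, eval_batch_size=16)
model = Vits(config)
```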
@@ -85,6 +85,7 @@ We still support running training from CLI like in the old days.

```json
{
    "run_name": "my_run",
    "model": "glow_tts",
    "batch_size": 32,
    "eval_batch_size": 16,

@@ -25,9 +25,7 @@
     "import umap\n",
     "\n",
-    "from TTS.speaker_encoder.model import SpeakerEncoder\n",
-    "from TTS.utils.audio import AudioProcessor
-\n",
+    "from TTS.utils.audio import AudioProcessor\n",
     "from TTS.tts.utils.generic_utils import load_config\n",
     "\n",
     "from bokeh.io import output_notebook, show\n",

@@ -1,12 +1,12 @@
 import os

-from TTS.tts.configs import AlignTTSConfig
-from TTS.tts.configs import BaseDatasetConfig
-from TTS.trainer import init_training, Trainer, TrainingArgs
-
+from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.tts.configs import AlignTTSConfig, BaseDatasetConfig

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/"))
+dataset_config = BaseDatasetConfig(
+    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+)
 config = AlignTTSConfig(
     batch_size=32,
     eval_batch_size=16,

@@ -23,8 +23,8 @@ config = AlignTTSConfig(
     print_eval=True,
     mixed_precision=False,
     output_path=output_path,
-    datasets=[dataset_config]
+    datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()

@@ -1,12 +1,12 @@
 import os

-from TTS.tts.configs import GlowTTSConfig
-from TTS.tts.configs import BaseDatasetConfig
-from TTS.trainer import init_training, Trainer, TrainingArgs
-
+from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/"))
+dataset_config = BaseDatasetConfig(
+    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+)
 config = GlowTTSConfig(
     batch_size=32,
     eval_batch_size=16,

@@ -23,8 +23,8 @@ config = GlowTTSConfig(
     print_eval=True,
     mixed_precision=False,
     output_path=output_path,
-    datasets=[dataset_config]
+    datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = HifiganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()

@@ -1,8 +1,7 @@
 import os

+from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.vocoder.configs import MultibandMelganConfig
-from TTS.trainer import init_training, Trainer, TrainingArgs
-

 output_path = os.path.dirname(os.path.abspath(__file__))
 config = MultibandMelganConfig(

@@ -25,6 +24,6 @@ config = MultibandMelganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()

@@ -1,6 +1,5 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
 from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.vocoder.configs import UnivnetConfig
-

@@ -25,6 +24,6 @@ config = UnivnetConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()

@@ -0,0 +1,52 @@
import os

from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, VitsConfig

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    preemphasis=0.0,
    ref_level_db=20,
    log_func="np.log",
    do_trim_silence=True,
    trim_db=45,
    mel_fmin=0,
    mel_fmax=None,
    spec_gain=1.0,
    signal_norm=False,
    do_amp_to_db_linear=False,
)
config = VitsConfig(
    audio=audio_config,
    run_name="vits_ljspeech",
    batch_size=48,
    eval_batch_size=16,
    batch_group_size=0,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    compute_input_seq_cache=True,
    print_step=25,
    print_eval=True,
    mixed_precision=True,
    max_seq_len=5000,
    output_path=output_path,
    datasets=[dataset_config],
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
trainer.fit()

@@ -1,10 +1,8 @@
 import os

-from TTS.trainer import Trainer, init_training
-from TTS.trainer import TrainingArgs
+from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.vocoder.configs import WavegradConfig


 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavegradConfig(
     batch_size=32,

@@ -24,6 +22,6 @@ config = WavegradConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()

@@ -1,9 +1,8 @@
 import os

-from TTS.trainer import Trainer, init_training, TrainingArgs
+from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.vocoder.configs import WavernnConfig


 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavernnConfig(
     batch_size=64,

@@ -25,6 +24,6 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
 trainer.fit()

@@ -24,3 +24,4 @@ mecab-python3==1.0.3
 unidic-lite==1.0.8
 # gruut+supported langs
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
+fsspec>=2021.04.0

@@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase):
             r,
             c.text_cleaner,
             compute_linear_spec=True,
+            return_wav=True,
             ap=self.ap,
             meta_data=items,
             characters=c.characters,

@@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase):
             mel_lengths = data[5]
             stop_target = data[6]
             item_idx = data[7]
+            wavs = data[11]

             neg_values = text_input[text_input < 0]
             check_count = len(neg_values)
             assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
+            assert isinstance(speaker_name[0], str)
+            assert linear_input.shape[0] == c.batch_size
+            assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
+            assert mel_input.shape[0] == c.batch_size
+            assert mel_input.shape[2] == c.audio["num_mels"]
+            assert (
+                wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
+            ), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
+
+            # make sure that the computed mels and the waveform match and are correctly computed
+            mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
+            ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
+            mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
+            assert abs(mel_diff.sum()) < 1e-5

             # check normalization ranges
             if self.ap.symmetric_norm:
                 assert mel_input.max() <= self.ap.max_norm

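The `wavs.shape` assertion above is plain frame accounting: each mel frame advances the signal by `hop_length` samples, so the collated waveform must be exactly `num_frames * hop_length` samples long. For the LJSpeech settings used elsewhere in this diff:

```python
hop_length = 256   # samples the analysis window advances per mel frame
num_frames = 30    # mel_input.shape[1] in the test above
expected_samples = num_frames * hop_length
assert expected_samples == 7680
```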
@@ -27,6 +27,7 @@ config = AlignTTSConfig(
         "Be a voice, not an echo.",
     ],
 )
+
 config.audio.do_trim_silence = True
 config.audio.trim_db = 60
 config.save_json(config_path)

@@ -29,6 +29,7 @@ config = Tacotron2Config(
         "Be a voice, not an echo.",
     ],
+    d_vector_file="tests/data/ljspeech/speakers.json",
     d_vector_dim=256,
     max_decoder_steps=50,
 )

@@ -25,8 +25,68 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


class TacotronTrainTest(unittest.TestCase):
    """Test vanilla Tacotron2 model."""

    def test_train_step(self):  # pylint: disable=no-self-use
        config = config_global.copy()
        config.use_speaker_embedding = False
        config.num_speakers = 1

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
        mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron2(config).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
        for i in range(5):
            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
            assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
            assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
            optimizer.zero_grad()
            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
            loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore the pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class MultiSpeakerTacotronTrainTest(unittest.TestCase):
    """Test multi-speaker Tacotron2 with speaker embedding layer"""

    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_speaker_embedding = True
        config.num_speakers = 5

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]

@@ -45,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):
         criterion = MSELossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
+        config.d_vector_dim = 55
         model = Tacotron2(config).to(device)
         model.train()
         model_ref = copy.deepcopy(model)

@@ -76,65 +137,18 @@ class TacotronTrainTest(unittest.TestCase):
            count += 1


class MultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
        mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.rand(8, 55).to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        config.d_vector_dim = 55
        model = Tacotron2(config).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
        for i in range(5):
            outputs = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_ids}
            )
            assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
            assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
            optimizer.zero_grad()
            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
            loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore the pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class TacotronGSTTrainTest(unittest.TestCase):
    """Test multi-speaker Tacotron2 with Global Style Token and Speaker Embedding"""

    # pylint: disable=no-self-use
    def test_train_step(self):
        # with random gst mel style
        config = config_global.copy()
        config.use_speaker_embedding = True
        config.num_speakers = 10
        config.use_gst = True
        config.gst = GSTConfig()

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]

@@ -247,9 +261,17 @@ class TacotronGSTTrainTest(unittest.TestCase):


class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
    """Test multi-speaker Tacotron2 with Global Style Tokens and d-vector inputs."""

    @staticmethod
    def test_train_step():

        config = config_global.copy()
        config.use_d_vector_file = True

        config.use_gst = True
        config.gst = GSTConfig()

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]

@@ -0,0 +1,55 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

config = Tacotron2Config(
    r=5,
    batch_size=8,
    eval_batch_size=8,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=False,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
    print_eval=True,
    max_decoder_steps=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
    f"--coqpit.output_path file://{output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
)
run_cli(command_train)
shutil.rmtree(continue_path)

@@ -32,6 +32,61 @@ class TacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_speaker_embedding = False
        config.num_speakers = 1

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
        linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[-1] = mel_spec.size(1)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = L1LossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
        model.train()
        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
        for _ in range(5):
            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
            optimizer.zero_grad()
            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
            loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore the pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class MultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_speaker_embedding = True
        config.num_speakers = 5

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128

@@ -50,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):
         criterion = L1LossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
+        config.d_vector_dim = 55
         model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
         model.train()
         print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))

@@ -80,63 +136,14 @@ class TacotronTrainTest(unittest.TestCase):
            count += 1


class MultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
        linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[-1] = mel_spec.size(1)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_embeddings = torch.rand(8, 55).to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = L1LossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        config.d_vector_dim = 55
        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
        model.train()
        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
        for _ in range(5):
            outputs = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
            )
            optimizer.zero_grad()
            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
            loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore the pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class TacotronGSTTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_speaker_embedding = True
        config.num_speakers = 10
        config.use_gst = True
        config.gst = GSTConfig()
        # with random gst mel style
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)

@@ -244,6 +251,11 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_d_vector_file = True

        config.use_gst = True
        config.gst = GSTConfig()

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128

@@ -0,0 +1,54 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
