mirror of https://github.com/coqui-ai/TTS.git
commit 01a2b0b5c0
@@ -1,10 +1,15 @@
----
-name: 'Contribution Guideline '
-about: Refer to Contirbution Guideline
-title: ''
-labels: ''
-assignees: ''
-
----
-
-👐 Please check our [CONTRIBUTION GUIDELINE](https://github.com/coqui-ai/TTS#contribution-guidelines).
+# Pull request guidelines
+
+Welcome to the 🐸TTS project! We are excited to see your interest, and appreciate your support!
+
+This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.
+
+In order to make a good pull request, please see our [CONTRIBUTING.md](CONTRIBUTING.md) file.
+
+Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS).
+
+This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS):
+
+- Protects you, Coqui, and the users of the code.
+- Does not change your rights to use your contributions for any purpose.
+- Does not change the license of the 🐸TTS project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute.

@@ -136,6 +136,7 @@ TTS/tts/layers/glow_tts/monotonic_align/core.c
temp_build/*
recipes/WIP/*
recipes/ljspeech/LJSpeech-1.1/*
+recipes/ljspeech/tacotron2-DDC/LJSpeech-1.1/*
events.out*
old_configs/*
model_importers/*

@@ -152,4 +153,6 @@ output.wav
tts_output.wav
deps.json
speakers.json
internal/*
+*_pitch.npy
+*_phoneme.npy

Makefile
@@ -4,7 +4,7 @@
help:
    @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

-target_dirs := tests TTS notebooks
+target_dirs := tests TTS notebooks recipes

test_all:  ## run tests and don't stop on an error.
    nosetests --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id

@@ -73,10 +73,13 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950)

+### End-to-End Models
+- VITS: [paper](https://arxiv.org/pdf/2106.06103)
+
### Attention Methods
- Guided Attention: [paper](https://arxiv.org/abs/1710.08969)
- Forward Backward Decoding: [paper](https://arxiv.org/abs/1907.09006)
-- Graves Attention: [paper](https://arxiv.org/abs/1907.09006)
+- Graves Attention: [paper](https://arxiv.org/abs/1910.10288)
- Double Decoder Consistency: [blog](https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency/)
- Dynamic Convolutional Attention: [paper](https://arxiv.org/pdf/1910.10288.pdf)

TTS/.models.json
@@ -1,7 +1,7 @@
{
-    "tts_models":{
-        "en":{
-            "ek1":{
+    "tts_models": {
+        "en": {
+            "ek1": {
                "tacotron2": {
                    "description": "EK1 en-rp tacotron2 by NMStoker",
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ek1--tacotron2.zip",
@@ -9,7 +9,7 @@
                    "commit": "c802255"
                }
            },
-            "ljspeech":{
+            "ljspeech": {
                "tacotron2-DDC": {
                    "description": "Tacotron2 with Double Decoder Consistency.",
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/tts_models--en--ljspeech--tacotron2-DDC.zip",
@@ -17,9 +17,18 @@
                    "commit": "bae2ad0f",
                    "author": "Eren Gölge @erogol",
                    "license": "",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                },
-                "glow-tts":{
+                "tacotron2-DDC_ph": {
+                    "description": "Tacotron2 with Double Decoder Consistency with phonemes.",
+                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip",
+                    "default_vocoder": "vocoder_models/en/ljspeech/univnet",
+                    "commit": "3900448",
+                    "author": "Eren Gölge @erogol",
+                    "license": "",
+                    "contact": "egolge@coqui.com"
+                },
+                "glow-tts": {
                    "description": "",
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--en--ljspeech--glow-tts.zip",
                    "stats_file": null,
@@ -27,7 +36,7 @@
                    "commit": "",
                    "author": "Eren Gölge @erogol",
                    "license": "MPL",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                },
                "tacotron2-DCA": {
                    "description": "",
@@ -36,19 +45,28 @@
                    "commit": "",
                    "author": "Eren Gölge @erogol",
                    "license": "MPL",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                },
-                "speedy-speech-wn":{
+                "speedy-speech-wn": {
                    "description": "Speedy Speech model with wavenet decoder.",
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ljspeech--speedy-speech-wn.zip",
                    "default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan",
                    "commit": "77b6145",
                    "author": "Eren Gölge @erogol",
                    "license": "MPL",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
+                },
+                "vits": {
+                    "description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
+                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--vits.zip",
+                    "default_vocoder": null,
+                    "commit": "3900448",
+                    "author": "Eren Gölge @erogol",
+                    "license": "",
+                    "contact": "egolge@coqui.com"
                }
            },
-            "vctk":{
+            "vctk": {
                "sc-glow-tts": {
                    "description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.",
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip",
@@ -56,12 +74,19 @@
                    "commit": "b531fa69",
                    "author": "Edresson Casanova",
                    "license": "",
-                    "contact":""
+                    "contact": ""
+                },
+                "vits": {
+                    "description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.",
+                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--vctk--vits.zip",
+                    "default_vocoder": null,
+                    "commit": "3900448",
+                    "author": "Eren @erogol",
+                    "license": "",
+                    "contact": "egolge@coqui.ai"
                }
            },
-            "sam":{
+            "sam": {
                "tacotron-DDC": {
                    "description": "Tacotron2 with Double Decoder Consistency trained with Aceenture's Sam dataset.",
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/tts_models--en--sam--tacotron_DDC.zip",
@@ -69,46 +94,47 @@
                    "commit": "bae2ad0f",
                    "author": "Eren Gölge @erogol",
                    "license": "",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                }
            }
        },
-        "es":{
-            "mai":{
-                "tacotron2-DDC":{
+        "es": {
+            "mai": {
+                "tacotron2-DDC": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--es--mai--tacotron2-DDC.zip",
                    "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
                    "commit": "",
                    "author": "Eren Gölge @erogol",
                    "license": "MPL",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                }
            }
        },
-        "fr":{
-            "mai":{
-                "tacotron2-DDC":{
+        "fr": {
+            "mai": {
+                "tacotron2-DDC": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--fr--mai--tacotron2-DDC.zip",
                    "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
                    "commit": "",
                    "author": "Eren Gölge @erogol",
                    "license": "MPL",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                }
            }
        },
-        "zh-CN":{
-            "baker":{
-                "tacotron2-DDC-GST":{
+        "zh-CN": {
+            "baker": {
+                "tacotron2-DDC-GST": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
                    "commit": "unknown",
-                    "author": "@kirianguiller"
+                    "author": "@kirianguiller",
+                    "default_vocoder": null
                }
            }
        },
-        "nl":{
-            "mai":{
-                "tacotron2-DDC":{
+        "nl": {
+            "mai": {
+                "tacotron2-DDC": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--nl--mai--tacotron2-DDC.zip",
                    "author": "@r-dh",
                    "default_vocoder": "vocoder_models/nl/mai/parallel-wavegan",
@@ -117,20 +143,9 @@
                }
            }
        },
-        "ru":{
-            "ruslan":{
-                "tacotron2-DDC":{
-                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--ru--ruslan--tacotron2-DDC.zip",
-                    "author": "@erogol",
-                    "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
-                    "license":"",
-                    "contact": "egolge@coqui.com"
-                }
-            }
-        },
-        "de":{
-            "thorsten":{
-                "tacotron2-DCA":{
+        "de": {
+            "thorsten": {
+                "tacotron2-DCA": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/tts_models--de--thorsten--tacotron2-DCA.zip",
                    "default_vocoder": "vocoder_models/de/thorsten/fullband-melgan",
                    "author": "@thorstenMueller",
@@ -138,9 +153,9 @@
                }
            }
        },
-        "ja":{
-            "kokoro":{
-                "tacotron2-DDC":{
+        "ja": {
+            "kokoro": {
+                "tacotron2-DDC": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.15/tts_models--jp--kokoro--tacotron2-DDC.zip",
                    "default_vocoder": "vocoder_models/universal/libri-tts/wavegrad",
                    "description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.",
@@ -150,54 +165,62 @@
            }
        }
    },
-    "vocoder_models":{
-        "universal":{
-            "libri-tts":{
-                "wavegrad":{
+    "vocoder_models": {
+        "universal": {
+            "libri-tts": {
+                "wavegrad": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--wavegrad.zip",
                    "commit": "ea976b0",
                    "author": "Eren Gölge @erogol",
                    "license": "MPL",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                },
-                "fullband-melgan":{
+                "fullband-melgan": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--fullband-melgan.zip",
                    "commit": "4132240",
                    "author": "Eren Gölge @erogol",
                    "license": "MPL",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                }
            }
        },
        "en": {
-            "ek1":{
+            "ek1": {
                "wavegrad": {
                    "description": "EK1 en-rp wavegrad by NMStoker",
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--en--ek1--wavegrad.zip",
                    "commit": "c802255"
                }
            },
-            "ljspeech":{
-                "multiband-melgan":{
+            "ljspeech": {
+                "multiband-melgan": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--en--ljspeech--mulitband-melgan.zip",
                    "commit": "ea976b0",
                    "author": "Eren Gölge @erogol",
                    "license": "MPL",
-                    "contact":"egolge@coqui.com"
+                    "contact": "egolge@coqui.com"
                },
-                "hifigan_v2":{
+                "hifigan_v2": {
                    "description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.",
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--ljspeech-hifigan_v2.zip",
                    "commit": "bae2ad0f",
                    "author": "@erogol",
                    "license": "",
                    "contact": "egolge@coqui.ai"
+                },
+                "univnet": {
+                    "description": "UnivNet model trained on LJSpeech to complement the TacotronDDC_ph model.",
+                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/vocoder_models--en--ljspeech--univnet.zip",
+                    "commit": "3900448",
+                    "author": "Eren @erogol",
+                    "license": "",
+                    "contact": "egolge@coqui.ai"
                }
            },
-            "vctk":{
-                "hifigan_v2":{
+            "vctk": {
+                "hifigan_v2": {
                    "description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts",
-                    "github_rls_url":"https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--vctk--hifigan_v2.zip",
+                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--vctk--hifigan_v2.zip",
                    "commit": "2f07160",
                    "author": "Edresson Casanova",
                    "license": "",
@@ -205,9 +228,9 @@
                }
            },
            "sam": {
-                "hifigan_v2":{
+                "hifigan_v2": {
                    "description": "Finetuned and intended to be used with tts_models/en/sam/tacotron_DDC",
-                    "github_rls_url":"https://github.com/coqui-ai/TTS/releases/download/v0.0.13/vocoder_models--en--sam--hifigan_v2.zip",
+                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/vocoder_models--en--sam--hifigan_v2.zip",
                    "commit": "2f07160",
                    "author": "Eren Gölge @erogol",
                    "license": "",
@@ -215,28 +238,38 @@
                }
            }
        },
-        "nl":{
-            "mai":{
-                "parallel-wavegan":{
+        "nl": {
+            "mai": {
+                "parallel-wavegan": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--nl--mai--parallel-wavegan.zip",
                    "author": "@r-dh",
                    "commit": "unknown"
                }
            }
        },
-        "de":{
-            "thorsten":{
-                "wavegrad":{
+        "de": {
+            "thorsten": {
+                "wavegrad": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/vocoder_models--de--thorsten--wavegrad.zip",
                    "author": "@thorstenMueller",
                    "commit": "unknown"
                },
-                "fullband-melgan":{
+                "fullband-melgan": {
                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.3/vocoder_models--de--thorsten--fullband-melgan.zip",
                    "author": "@thorstenMueller",
                    "commit": "unknown"
                }
            }
+        },
+        "ja": {
+            "kokoro": {
+                "hifigan_v1": {
+                    "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/vocoder_models--ja--kokoro--hifigan_v1.zip",
+                    "description": "HifiGAN model trained for kokoro dataset by @kaiidams",
+                    "author": "@kaiidams",
+                    "commit": "3900448"
+                }
+            }
        }
    }
}

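The nesting in `.models.json` mirrors the model-name convention `<model_type>/<language>/<dataset>/<model_name>` used on the command line (for example `tts_models/en/ljspeech/tacotron2-DDC_ph`). A small illustration of that mapping; `lookup_model` here is an ad-hoc helper written for this example, not part of the TTS API:

import json


def lookup_model(models, model_name):
    # "tts_models/en/ljspeech/tacotron2-DDC_ph" ->
    # models["tts_models"]["en"]["ljspeech"]["tacotron2-DDC_ph"]
    model_type, lang, dataset, name = model_name.split("/")
    return models[model_type][lang][dataset][name]


with open("TTS/.models.json", "r", encoding="utf-8") as f:
    models = json.load(f)

entry = lookup_model(models, "tts_models/en/ljspeech/tacotron2-DDC_ph")
print(entry["github_rls_url"], entry["default_vocoder"])
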
@@ -6,7 +6,7 @@ import numpy as np
import tensorflow as tf
import torch

-from TTS.utils.io import load_config
+from TTS.utils.io import load_config, load_fsspec
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
    compare_torch_tf,
    convert_tf_name,
@@ -33,7 +33,7 @@ num_speakers = 0

# init torch model
model = setup_generator(c)
-checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
+checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()

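The hunks above and below swap `torch.load` for a `load_fsspec` helper imported from TTS/utils/io.py, but the helper itself is not part of this diff. A minimal sketch of what such a wrapper could look like, assuming it simply routes the checkpoint through fsspec so that local paths and remote URLs (s3://, gs://, and similar) behave the same:

import fsspec
import torch


def load_fsspec(path, map_location=None, **kwargs):
    # Open the checkpoint through fsspec (local files or any protocol fsspec
    # supports) and hand the file object to torch.load.
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)
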
@@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.text.symbols import phonemes, symbols
-from TTS.utils.io import load_config
+from TTS.utils.io import load_config, load_fsspec

sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -32,7 +32,7 @@ num_speakers = 0

# init torch model
model = setup_model(c)
-checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
+checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)

@@ -16,6 +16,7 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
+from TTS.utils.io import load_fsspec

use_cuda = torch.cuda.is_available()

@@ -239,7 +240,7 @@ def main(args):  # pylint: disable=redefined-outer-name
    model = setup_model(c)

    # restore model
-    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+    checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])

    if use_cuda:
@@ -208,7 +208,7 @@ def main():
    if args.vocoder_name is not None and not args.vocoder_path:
        vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

-    # CASE3: set custome model paths
+    # CASE3: set custom model paths
    if args.model_path is not None:
        model_path = args.model_path
        config_path = args.config_path
@@ -17,6 +17,7 @@ from TTS.trainer import init_training
from TTS.tts.datasets import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
+from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, check_update

@@ -115,12 +116,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
                "step_time": step_time,
                "avg_loader_time": avg_loader_time,
            }
-            tb_logger.tb_train_epoch_stats(global_step, train_stats)
+            dashboard_logger.train_epoch_stats(global_step, train_stats)
            figures = {
                # FIXME: not constant
                "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
            }
-            tb_logger.tb_train_figures(global_step, figures)
+            dashboard_logger.train_figures(global_step, figures)

        if global_step % c.print_step == 0:
            print(
@@ -169,7 +170,7 @@ def main(args):  # pylint: disable=redefined-outer-name
        raise Exception("The %s not is a loss supported" % c.loss)

    if args.restore_path:
-        checkpoint = torch.load(args.restore_path)
+        checkpoint = load_fsspec(args.restore_path)
        try:
            model.load_state_dict(checkpoint["model"])

@@ -207,7 +208,7 @@ def main(args):  # pylint: disable=redefined-outer-name


if __name__ == "__main__":
-    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
+    args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv)

    try:
        main(args)
@@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training

def main():
    """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
-    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-    trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False)
+    args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
+    trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False)
    trainer.fit()

@@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder

def main():
    try:
-        args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-        trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+        args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
+        trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
        trainer.fit()
    except KeyboardInterrupt:
        remove_experiment_folder(output_path)
@@ -3,6 +3,7 @@ import os
import re
from typing import Dict

+import fsspec
import yaml
from coqpit import Coqpit

@@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module
def read_json_with_comments(json_path):
    """for backward compat."""
    # fallback to json
-    with open(json_path, "r", encoding="utf-8") as f:
+    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    # handle comments
    input_str = re.sub(r"\\\n", "", input_str)
@@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
    config_dict = {}
    ext = os.path.splitext(config_path)[1]
    if ext in (".yml", ".yaml"):
-        with open(config_path, "r", encoding="utf-8") as f:
+        with fsspec.open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    elif ext == ".json":
        try:
-            with open(config_path, "r", encoding="utf-8") as f:
-                input_str = f.read()
-            data = json.loads(input_str)
+            with fsspec.open(config_path, "r", encoding="utf-8") as f:
+                data = json.load(f)
        except json.decoder.JSONDecodeError:
            # backwards compat.
            data = read_json_with_comments(config_path)

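With `open` replaced by `fsspec.open`, the same `load_config` call should accept remote URLs as well as local paths, provided the matching fsspec backend package is installed. A hedged usage sketch; the bucket URL is a made-up placeholder and the import path is assumed:

from TTS.config import load_config  # import path assumed

# local path, as before
local_config = load_config("config.json")

# gs:// URLs need the gcsfs package so fsspec can resolve the protocol;
# "gs://my-bucket/run/config.json" is only an illustrative path
remote_config = load_config("gs://my-bucket/run/config.json")
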
@@ -36,6 +36,10 @@ class BaseAudioConfig(Coqpit):
            Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
        do_trim_silence (bool):
            Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
+        do_amp_to_db_linear (bool, optional):
+            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
+        do_amp_to_db_mel (bool, optional):
+            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
        trim_db (int):
            Silence threshold used for silence trimming. Defaults to 45.
        power (float):
@@ -79,7 +83,7 @@ class BaseAudioConfig(Coqpit):
    preemphasis: float = 0.0
    ref_level_db: int = 20
    do_sound_norm: bool = False
-    log_func = "np.log10"
+    log_func: str = "np.log10"
    # silence trimming
    do_trim_silence: bool = True
    trim_db: int = 45
@@ -91,6 +95,8 @@ class BaseAudioConfig(Coqpit):
    mel_fmin: float = 0.0
    mel_fmax: float = None
    spec_gain: int = 20
+    do_amp_to_db_linear: bool = True
+    do_amp_to_db_mel: bool = True
    # normalization params
    signal_norm: bool = True
    min_level_db: int = -100

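A minimal sketch of setting the two new flags on `BaseAudioConfig`; the values and the import path are illustrative, not taken from this commit:

from TTS.config.shared_configs import BaseAudioConfig  # module path assumed

audio_config = BaseAudioConfig(
    sample_rate=22050,          # example value
    do_amp_to_db_linear=True,   # amplitude-to-dB conversion for linear spectrograms
    do_amp_to_db_mel=True,      # amplitude-to-dB conversion for mel spectrograms
)
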
@@ -182,51 +188,87 @@ class BaseTrainingConfig(Coqpit):
    Args:
        model (str):
            Name of the model that is used in the training.

        run_name (str):
            Name of the experiment. This prefixes the output folder name.

        run_description (str):
            Short description of the experiment.

        epochs (int):
            Number training epochs. Defaults to 10000.

        batch_size (int):
            Training batch size.

        eval_batch_size (int):
            Validation batch size.

        mixed_precision (bool):
            Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
            it may also cause numerical unstability in some cases.

+        scheduler_after_epoch (bool):
+            If true, run the scheduler step after each epoch else run it after each model step.
+
        run_eval (bool):
            Enable / Disable evaluation (validation) run. Defaults to True.

        test_delay_epochs (int):
            Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
            results, hence waiting for a couple of epochs might save some time.

        print_eval (bool):
            Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at
            the end of the evaluation. Default to ```False```.

        print_step (int):
            Number of steps required to print the next training log.
-        tb_plot_step (int):
+
+        log_dashboard (str): "tensorboard" or "wandb"
+            Set the experiment tracking tool
+
+        plot_step (int):
            Number of steps required to log training on Tensorboard.
-        tb_model_param_stats (bool):
+
+        model_param_stats (bool):
            Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
            Defaults to ```False```.

+        project_name (str):
+            Name of the project. Defaults to config.model
+
+        wandb_entity (str):
+            Name of W&B entity/team. Enables collaboration across a team or org.
+
+        log_model_step (int):
+            Number of steps required to log a checkpoint as W&B artifact
+
        save_step (int):ipt
            Number of steps required to save the next checkpoint.

        checkpoint (bool):
            Enable / Disable checkpointing.

        keep_all_best (bool):
            Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
            to ```False```.

        keep_after (int):
            Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
            to 10000.

        num_loader_workers (int):
            Number of workers for training time dataloader.

        num_eval_loader_workers (int):
            Number of workers for evaluation time dataloader.

        output_path (str):
-            Path for training output folder. The nonexist part of the given path is created automatically.
-            All training outputs are saved there.
+            Path for training output folder, either a local file path or other
+            URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
+            S3 (s3://) paths. The nonexist part of the given path is created
+            automatically. All training artefacts are saved there.
    """

    model: str = None
@@ -237,14 +279,19 @@ class BaseTrainingConfig(Coqpit):
    batch_size: int = None
    eval_batch_size: int = None
    mixed_precision: bool = False
+    scheduler_after_epoch: bool = False
    # eval params
    run_eval: bool = True
    test_delay_epochs: int = 0
    print_eval: bool = False
    # logging
+    dashboard_logger: str = "tensorboard"
    print_step: int = 25
-    tb_plot_step: int = 100
-    tb_model_param_stats: bool = False
+    plot_step: int = 100
+    model_param_stats: bool = False
+    project_name: str = None
+    log_model_step: int = None
+    wandb_entity: str = None
    # checkpointing
    save_step: int = 10000
    checkpoint: bool = True

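A hedged sketch of how the new logging fields and the fsspec-style `output_path` described in the docstring earlier in this file's diff might be combined; the project, entity, and bucket names are placeholders, not values from this commit:

from TTS.config.shared_configs import BaseTrainingConfig  # module path assumed

config = BaseTrainingConfig(
    model="tacotron2",
    run_name="ljspeech-run",
    dashboard_logger="wandb",        # or "tensorboard" (the default)
    project_name="tts-experiments",  # W&B project name (placeholder)
    wandb_entity="my-team",          # W&B entity/team (placeholder)
    log_model_step=10000,            # log checkpoints as W&B artifacts every N steps
    output_path="s3://my-bucket/tts-runs/",  # remote output path resolved via fsspec (placeholder)
)
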
@@ -103,8 +103,8 @@ synthesizer = Synthesizer(
    model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
)

-use_multi_speaker = synthesizer.tts_model.speaker_manager is not None and synthesizer.tts_model.num_speakers > 1
-speaker_manager = synthesizer.tts_model.speaker_manager if hasattr(synthesizer.tts_model, "speaker_manager") else None
+use_multi_speaker = hasattr(synthesizer.tts_model, "speaker_manager") and synthesizer.tts_model.num_speakers > 1
+speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
# TODO: set this from SpeakerManager
use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__)

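`getattr(obj, "attr", default)` returns the default instead of raising `AttributeError`, so models without a `speaker_manager` attribute no longer need an explicit `hasattr` guard for that lookup. A tiny illustration:

class SingleSpeakerModel:
    num_speakers = 1


model = SingleSpeakerModel()
speaker_manager = getattr(model, "speaker_manager", None)  # -> None, no exception raised
use_multi_speaker = hasattr(model, "speaker_manager") and model.num_speakers > 1  # -> False
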
@@ -2,6 +2,8 @@ import numpy as np
import torch
from torch import nn

+from TTS.utils.io import load_fsspec
+

class LSTMWithProjection(nn.Module):
    def __init__(self, input_size, hidden_size, proj_size):
@@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module):

    # pylint: disable=unused-argument, redefined-builtin
    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if use_cuda:
            self.cuda()
@@ -2,6 +2,8 @@ import numpy as np
import torch
import torch.nn as nn

+from TTS.utils.io import load_fsspec
+

class SELayer(nn.Module):
    def __init__(self, channel, reduction=8):
@@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module):
        return embeddings

    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if use_cuda:
            self.cuda()
@@ -6,11 +6,11 @@ import re
from multiprocessing import Manager

import numpy as np
-import torch
from scipy import signal

from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
+from TTS.utils.io import save_fsspec


class Storage(object):
@@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step):
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
-    torch.save(state, checkpoint_path)
+    save_fsspec(state, checkpoint_path)


def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
@@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
    bestmodel_path = "best_model.pth.tar"
    bestmodel_path = os.path.join(out_path, bestmodel_path)
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-    torch.save(state, bestmodel_path)
+    save_fsspec(state, bestmodel_path)
    return best_loss

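As with `load_fsspec`, the `save_fsspec` helper used above is imported from TTS/utils/io.py and is not defined in this diff; a minimal sketch under the assumption that it mirrors `torch.save` through an fsspec file object:

import fsspec
import torch


def save_fsspec(state, path, **kwargs):
    # Write the checkpoint through fsspec so local paths and remote URLs work alike.
    with fsspec.open(path, "wb") as f:
        torch.save(state, f, **kwargs)
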
@@ -1,7 +1,7 @@
import datetime
import os

-import torch
+from TTS.utils.io import save_fsspec


def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
@@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
-    torch.save(state, checkpoint_path)
+    save_fsspec(state, checkpoint_path)


def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
@@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
    bestmodel_path = "best_model.pth.tar"
    bestmodel_path = os.path.join(out_path, bestmodel_path)
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-    torch.save(state, bestmodel_path)
+    save_fsspec(state, bestmodel_path)
    return best_loss

TTS/trainer.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

-import glob
import importlib
import logging
+import multiprocessing
import os
import platform
import re
@@ -12,7 +12,9 @@ import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
+from urllib.parse import urlparse

+import fsspec
import torch
from coqpit import Coqpit
from torch import nn
@@ -29,18 +31,20 @@ from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (
    KeepAverage,
    count_parameters,
-    create_experiment_folder,
+    get_experiment_folder_path,
    get_git_branch,
    remove_experiment_folder,
    set_init_dict,
    to_cuda,
)
-from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
+from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model

+multiprocessing.set_start_method("fork")
+
if platform.system() != "Windows":
    # https://github.com/pytorch/pytorch/issues/973
    import resource
@@ -48,6 +52,7 @@ if platform.system() != "Windows":
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))


if is_apex_available():
    from apex import amp

@@ -87,7 +92,7 @@ class Trainer:
        config: Coqpit,
        output_path: str,
        c_logger: ConsoleLogger = None,
-        tb_logger: TensorboardLogger = None,
+        dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
        model: nn.Module = None,
        cudnn_benchmark: bool = False,
    ) -> None:
@@ -112,7 +117,7 @@ class Trainer:
            c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default
                console logger is used. Defaults to None.

-            tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
+            dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used.
                Defaults to None.

            model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
@@ -134,8 +139,8 @@ class Trainer:
        Running trainer on a config.

        >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
-        >>> args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-        >>> trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+        >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+        >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
        >>> trainer.fit()

        TODO:
@@ -148,22 +153,24 @@ class Trainer:

        # set and initialize Pytorch runtime
        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)

        if config is None:
            # parse config from console arguments
-            config, output_path, _, c_logger, tb_logger = process_args(args)
+            config, output_path, _, c_logger, dashboard_logger = process_args(args)

        self.output_path = output_path
        self.args = args
        self.config = config
+        self.config.output_log_path = output_path
        # init loggers
        self.c_logger = ConsoleLogger() if c_logger is None else c_logger
-        if tb_logger is None:
-            self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
-            self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
-        else:
-            self.tb_logger = tb_logger
+        self.dashboard_logger = dashboard_logger
+
+        if self.dashboard_logger is None:
+            self.dashboard_logger = init_logger(config)
+
+        if not self.config.log_model_step:
+            self.config.log_model_step = self.config.save_step

        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
        self._setup_logger_config(log_file)

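`init_logger` is imported from `TTS.utils.logging` earlier in this diff but its body is not part of the changes shown here. A hedged sketch of the kind of dispatch it could perform, based on the new `dashboard_logger`, `project_name`, `wandb_entity`, and `output_log_path` config fields; the constructor arguments below are assumptions, not the actual signatures:

from TTS.utils.logging import TensorboardLogger, WandbLogger


def init_logger(config):
    # Pick the experiment-tracking backend from the config; default to Tensorboard.
    if config.dashboard_logger == "wandb":
        # argument names are assumed for illustration
        return WandbLogger(project=config.project_name, entity=config.wandb_entity, config=config)
    return TensorboardLogger(config.output_log_path, model_name=config.model)
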
@@ -173,7 +180,6 @@ class Trainer:
        self.best_loss = float("inf")
        self.train_loader = None
        self.eval_loader = None
-        self.output_audio_path = os.path.join(output_path, "test_audios")

        self.keep_avg_train = None
        self.keep_avg_eval = None
@@ -184,7 +190,7 @@ class Trainer:
        # init audio processor
        self.ap = AudioProcessor(**self.config.audio.to_dict())

-        # load dataset samples
+        # load data samples
        # TODO: refactor this
        if "datasets" in self.config:
            # load data for `tts` models
@@ -205,6 +211,10 @@ class Trainer:
        else:
            self.model = self.get_model(self.config)

+        # init multispeaker settings of the model
+        if hasattr(self.model, "init_multispeaker"):
+            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
+
        # setup criterion
        self.criterion = self.get_criterion(self.model)

@@ -274,9 +284,9 @@ class Trainer:
        """
        # TODO: better model setup
        try:
-            model = setup_tts_model(config)
-        except ModuleNotFoundError:
            model = setup_vocoder_model(config)
+        except ModuleNotFoundError:
+            model = setup_tts_model(config)
        return model

    def restore_model(
@@ -309,7 +319,7 @@ class Trainer:
            return obj

        print(" > Restoring from %s ..." % os.path.basename(restore_path))
-        checkpoint = torch.load(restore_path)
+        checkpoint = load_fsspec(restore_path)
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint["model"])
@ -417,7 +427,7 @@ class Trainer:
|
||||||
scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access
|
scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access
|
||||||
config: Coqpit,
|
config: Coqpit,
|
||||||
optimizer_idx: int = None,
|
optimizer_idx: int = None,
|
||||||
) -> Tuple[Dict, Dict, int, torch.Tensor]:
|
) -> Tuple[Dict, Dict, int]:
|
||||||
"""Perform a forward - backward pass and run the optimizer.
|
"""Perform a forward - backward pass and run the optimizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -426,7 +436,7 @@ class Trainer:
|
||||||
optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
|
optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
|
||||||
scaler (AMPScaler): AMP scaler.
|
scaler (AMPScaler): AMP scaler.
|
||||||
criterion (nn.Module): Model's criterion.
|
criterion (nn.Module): Model's criterion.
|
||||||
scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer.
|
scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
|
||||||
config (Coqpit): Model config.
|
config (Coqpit): Model config.
|
||||||
optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
|
optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
|
||||||
|
|
||||||
|
@ -436,6 +446,7 @@ class Trainer:
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
|
Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
# zero-out optimizer
|
# zero-out optimizer
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
@ -448,11 +459,11 @@ class Trainer:
|
||||||
# skip the rest
|
# skip the rest
|
||||||
if outputs is None:
|
if outputs is None:
|
||||||
step_time = time.time() - step_start_time
|
step_time = time.time() - step_start_time
|
||||||
return None, {}, step_time, 0
|
return None, {}, step_time
|
||||||
|
|
||||||
# check nan loss
|
# check nan loss
|
||||||
if torch.isnan(loss_dict["loss"]).any():
|
if torch.isnan(loss_dict["loss"]).any():
|
||||||
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
|
raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
|
||||||
|
|
||||||
# set gradient clipping threshold
|
# set gradient clipping threshold
|
||||||
if "grad_clip" in config and config.grad_clip is not None:
|
if "grad_clip" in config and config.grad_clip is not None:
|
||||||
|
@ -463,7 +474,6 @@ class Trainer:
|
||||||
else:
|
else:
|
||||||
grad_clip = 0.0 # meaning no gradient clipping
|
grad_clip = 0.0 # meaning no gradient clipping
|
||||||
|
|
||||||
# TODO: compute grad norm
|
|
||||||
if grad_clip <= 0:
|
if grad_clip <= 0:
|
||||||
grad_norm = 0
|
grad_norm = 0
|
||||||
|
|
||||||
|
@@ -474,15 +484,17 @@ class Trainer:
                with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
-                   amp.master_params(optimizer),
-                   grad_clip,
+                   amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
                )
            else:
                # model optimizer step in mixed precision mode
                scaler.scale(loss_dict["loss"]).backward()
-               scaler.unscale_(optimizer)
                if grad_clip > 0:
-                   grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+                   scaler.unscale_(optimizer)
+                   grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+                   # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
+                   if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                       grad_norm = 0
                scale_prev = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
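The ordering above is the usual torch AMP recipe: gradients are unscaled before clipping so the threshold applies to their true magnitudes, and a non-finite norm is zeroed out because GradScaler skips the optimizer step in that case anyway. A minimal, self-contained sketch of the same pattern (toy model and made-up values, not the Trainer code):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(10, 1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
    grad_clip = 5.0

    x, y = torch.randn(8, 10, device=device), torch.randn(8, 1, device=device)
    with torch.cuda.amp.autocast(enabled=(device == "cuda")):
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # bring gradients back to real scale before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
        grad_norm = 0  # the scaler will skip the step for non-finite gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()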
||||||
|
@ -491,13 +503,13 @@ class Trainer:
|
||||||
# main model optimizer step
|
# main model optimizer step
|
||||||
loss_dict["loss"].backward()
|
loss_dict["loss"].backward()
|
||||||
if grad_clip > 0:
|
if grad_clip > 0:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
step_time = time.time() - step_start_time
|
step_time = time.time() - step_start_time
|
||||||
|
|
||||||
# setup lr
|
# setup lr
|
||||||
if scheduler is not None and update_lr_scheduler:
|
if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# detach losses
|
# detach losses
|
||||||
|
@ -505,7 +517,9 @@ class Trainer:
|
||||||
if optimizer_idx is not None:
|
if optimizer_idx is not None:
|
||||||
loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
|
loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
|
||||||
loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
|
loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
|
||||||
return outputs, loss_dict, step_time, grad_norm
|
else:
|
||||||
|
loss_dict["grad_norm"] = grad_norm
|
||||||
|
return outputs, loss_dict, step_time
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _detach_loss_dict(loss_dict: Dict) -> Dict:
|
def _detach_loss_dict(loss_dict: Dict) -> Dict:
|
||||||
|
@ -544,11 +558,10 @@ class Trainer:
|
||||||
|
|
||||||
# conteainers to hold model outputs and losses for each optimizer.
|
# conteainers to hold model outputs and losses for each optimizer.
|
||||||
outputs_per_optimizer = None
|
outputs_per_optimizer = None
|
||||||
log_dict = {}
|
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
if not isinstance(self.optimizer, list):
|
if not isinstance(self.optimizer, list):
|
||||||
# training with a single optimizer
|
# training with a single optimizer
|
||||||
outputs, loss_dict_new, step_time, grad_norm = self._optimize(
|
outputs, loss_dict_new, step_time = self._optimize(
|
||||||
batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
|
batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
|
||||||
)
|
)
|
||||||
loss_dict.update(loss_dict_new)
|
loss_dict.update(loss_dict_new)
|
||||||
|
@ -560,25 +573,36 @@ class Trainer:
|
||||||
criterion = self.criterion
|
criterion = self.criterion
|
||||||
scaler = self.scaler[idx] if self.use_amp_scaler else None
|
scaler = self.scaler[idx] if self.use_amp_scaler else None
|
||||||
scheduler = self.scheduler[idx]
|
scheduler = self.scheduler[idx]
|
||||||
outputs, loss_dict_new, step_time, grad_norm = self._optimize(
|
outputs, loss_dict_new, step_time = self._optimize(
|
||||||
batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
|
batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
|
||||||
)
|
)
|
||||||
# skip the rest if the model returns None
|
# skip the rest if the model returns None
|
||||||
total_step_time += step_time
|
total_step_time += step_time
|
||||||
outputs_per_optimizer[idx] = outputs
|
outputs_per_optimizer[idx] = outputs
|
||||||
|
# merge loss_dicts from each optimizer
|
||||||
|
# rename duplicates with the optimizer idx
|
||||||
# if None, model skipped this optimizer
|
# if None, model skipped this optimizer
|
||||||
if loss_dict_new is not None:
|
if loss_dict_new is not None:
|
||||||
loss_dict.update(loss_dict_new)
|
for k, v in loss_dict_new.items():
|
||||||
|
if k in loss_dict:
|
||||||
|
loss_dict[f"{k}-{idx}"] = v
|
||||||
|
else:
|
||||||
|
loss_dict[k] = v
|
||||||
|
step_time = total_step_time
|
||||||
outputs = outputs_per_optimizer
|
outputs = outputs_per_optimizer
|
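The renaming above keeps losses from different optimizers from overwriting each other when they share a key. A standalone sketch of that merge rule (plain dicts, hypothetical loss names):

    def merge_loss_dicts(per_optimizer_losses):
        merged = {}
        for idx, new_losses in enumerate(per_optimizer_losses):
            if new_losses is None:  # the model skipped this optimizer
                continue
            for key, value in new_losses.items():
                if key in merged:
                    merged[f"{key}-{idx}"] = value  # rename duplicates with the optimizer idx
                else:
                    merged[key] = value
        return merged

    print(merge_loss_dicts([{"loss_0": 1.2, "feat_loss": 0.3}, {"loss_1": 0.9, "feat_loss": 0.4}]))
    # {'loss_0': 1.2, 'feat_loss': 0.3, 'loss_1': 0.9, 'feat_loss-1': 0.4}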
||||||
|
|
||||||
# update avg stats
|
# update avg runtime stats
|
||||||
keep_avg_update = dict()
|
keep_avg_update = dict()
|
||||||
for key, value in log_dict.items():
|
|
||||||
keep_avg_update["avg_" + key] = value
|
|
||||||
keep_avg_update["avg_loader_time"] = loader_time
|
keep_avg_update["avg_loader_time"] = loader_time
|
||||||
keep_avg_update["avg_step_time"] = step_time
|
keep_avg_update["avg_step_time"] = step_time
|
||||||
self.keep_avg_train.update_values(keep_avg_update)
|
self.keep_avg_train.update_values(keep_avg_update)
|
||||||
|
|
||||||
|
# update avg loss stats
|
||||||
|
update_eval_values = dict()
|
||||||
|
for key, value in loss_dict.items():
|
||||||
|
update_eval_values["avg_" + key] = value
|
||||||
|
self.keep_avg_train.update_values(update_eval_values)
|
||||||
|
|
||||||
# print training progress
|
# print training progress
|
||||||
if self.total_steps_done % self.config.print_step == 0:
|
if self.total_steps_done % self.config.print_step == 0:
|
||||||
# log learning rates
|
# log learning rates
|
||||||
|
@ -590,33 +614,27 @@ class Trainer:
|
||||||
else:
|
else:
|
||||||
current_lr = self.optimizer.param_groups[0]["lr"]
|
current_lr = self.optimizer.param_groups[0]["lr"]
|
||||||
lrs = {"current_lr": current_lr}
|
lrs = {"current_lr": current_lr}
|
||||||
log_dict.update(lrs)
|
|
||||||
if grad_norm > 0:
|
|
||||||
log_dict.update({"grad_norm": grad_norm})
|
|
||||||
# log run-time stats
|
# log run-time stats
|
||||||
log_dict.update(
|
loss_dict.update(
|
||||||
{
|
{
|
||||||
"step_time": round(step_time, 4),
|
"step_time": round(step_time, 4),
|
||||||
"loader_time": round(loader_time, 4),
|
"loader_time": round(loader_time, 4),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.c_logger.print_train_step(
|
self.c_logger.print_train_step(
|
||||||
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
|
batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.rank == 0:
|
if self.args.rank == 0:
|
||||||
# Plot Training Iter Stats
|
# Plot Training Iter Stats
|
||||||
# reduce TB load and don't log every step
|
# reduce TB load and don't log every step
|
||||||
if self.total_steps_done % self.config.tb_plot_step == 0:
|
if self.total_steps_done % self.config.plot_step == 0:
|
||||||
iter_stats = log_dict
|
self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict)
|
||||||
iter_stats.update(loss_dict)
|
|
||||||
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
|
|
||||||
if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
|
if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
|
||||||
if self.config.checkpoint:
|
if self.config.checkpoint:
|
||||||
# checkpoint the model
|
# checkpoint the model
|
||||||
model_loss = (
|
target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
|
||||||
loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
|
|
||||||
)
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
self.config,
|
self.config,
|
||||||
self.model,
|
self.model,
|
||||||
|
@ -625,8 +643,14 @@ class Trainer:
|
||||||
self.total_steps_done,
|
self.total_steps_done,
|
||||||
self.epochs_done,
|
self.epochs_done,
|
||||||
self.output_path,
|
self.output_path,
|
||||||
model_loss=model_loss,
|
model_loss=target_avg_loss,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.total_steps_done % self.config.log_model_step == 0:
|
||||||
|
# log checkpoint as artifact
|
||||||
|
aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
|
||||||
|
self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
|
||||||
|
|
||||||
# training visualizations
|
# training visualizations
|
||||||
figures, audios = None, None
|
figures, audios = None, None
|
||||||
if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
|
if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
|
||||||
|
@ -634,11 +658,13 @@ class Trainer:
|
||||||
elif hasattr(self.model, "train_log"):
|
elif hasattr(self.model, "train_log"):
|
||||||
figures, audios = self.model.train_log(self.ap, batch, outputs)
|
figures, audios = self.model.train_log(self.ap, batch, outputs)
|
||||||
if figures is not None:
|
if figures is not None:
|
||||||
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
|
self.dashboard_logger.train_figures(self.total_steps_done, figures)
|
||||||
if audios is not None:
|
if audios is not None:
|
||||||
self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
|
self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
|
||||||
|
|
||||||
self.total_steps_done += 1
|
self.total_steps_done += 1
|
||||||
self.callbacks.on_train_step_end()
|
self.callbacks.on_train_step_end()
|
||||||
|
self.dashboard_logger.flush()
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_epoch(self) -> None:
|
def train_epoch(self) -> None:
|
||||||
|
@ -663,9 +689,17 @@ class Trainer:
|
||||||
if self.args.rank == 0:
|
if self.args.rank == 0:
|
||||||
epoch_stats = {"epoch_time": epoch_time}
|
epoch_stats = {"epoch_time": epoch_time}
|
||||||
epoch_stats.update(self.keep_avg_train.avg_values)
|
epoch_stats.update(self.keep_avg_train.avg_values)
|
||||||
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
|
self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats)
|
||||||
if self.config.tb_model_param_stats:
|
if self.config.model_param_stats:
|
||||||
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
|
self.logger.model_weights(self.model, self.total_steps_done)
|
+       # scheduler step after the epoch
+       if self.scheduler is not None and self.config.scheduler_after_epoch:
+           if isinstance(self.scheduler, list):
+               for scheduler in self.scheduler:
+                   if scheduler is not None:
+                       scheduler.step()
+           else:
+               self.scheduler.step()
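With `scheduler_after_epoch` enabled, learning-rate decay is applied once per epoch instead of once per optimizer step. A toy sketch contrasting the two modes (illustrative loop, not the Trainer API):

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999875)
    scheduler_after_epoch = True

    for epoch in range(2):
        for step in range(100):
            optimizer.step()          # loss/backward omitted for brevity
            if not scheduler_after_epoch:
                scheduler.step()      # decay every step
        if scheduler_after_epoch:
            scheduler.step()          # decay once per epoch
        print(epoch, optimizer.param_groups[0]["lr"])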
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _model_eval_step(
|
def _model_eval_step(
|
||||||
|
@ -701,19 +735,22 @@ class Trainer:
|
||||||
Tuple[Dict, Dict]: Model outputs and losses.
|
Tuple[Dict, Dict]: Model outputs and losses.
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs_per_optimizer = None
|
outputs = []
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
if not isinstance(self.optimizer, list):
|
if not isinstance(self.optimizer, list):
|
||||||
outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
|
outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
|
||||||
else:
|
else:
|
||||||
outputs_per_optimizer = [None] * len(self.optimizer)
|
outputs = [None] * len(self.optimizer)
|
||||||
for idx, _ in enumerate(self.optimizer):
|
for idx, _ in enumerate(self.optimizer):
|
||||||
criterion = self.criterion
|
criterion = self.criterion
|
||||||
outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
|
outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
|
||||||
outputs_per_optimizer[idx] = outputs
|
outputs[idx] = outputs_
|
||||||
|
|
||||||
if loss_dict_new is not None:
|
if loss_dict_new is not None:
|
||||||
|
loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
|
||||||
loss_dict.update(loss_dict_new)
|
loss_dict.update(loss_dict_new)
|
||||||
outputs = outputs_per_optimizer
|
|
||||||
|
loss_dict = self._detach_loss_dict(loss_dict)
|
||||||
|
|
||||||
# update avg stats
|
# update avg stats
|
||||||
update_eval_values = dict()
|
update_eval_values = dict()
|
||||||
|
@ -755,28 +792,35 @@ class Trainer:
|
||||||
elif hasattr(self.model, "eval_log"):
|
elif hasattr(self.model, "eval_log"):
|
||||||
figures, audios = self.model.eval_log(self.ap, batch, outputs)
|
figures, audios = self.model.eval_log(self.ap, batch, outputs)
|
||||||
if figures is not None:
|
if figures is not None:
|
||||||
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
self.dashboard_logger.eval_figures(self.total_steps_done, figures)
|
||||||
if audios is not None:
|
if audios is not None:
|
||||||
self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
|
self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
|
||||||
self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
|
self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
|
||||||
|
|
||||||
def test_run(self) -> None:
|
def test_run(self) -> None:
|
||||||
"""Run test and log the results. Test run must be defined by the model.
|
"""Run test and log the results. Test run must be defined by the model.
|
||||||
Model must return figures and audios to be logged by the Tensorboard."""
|
Model must return figures and audios to be logged by the Tensorboard."""
|
||||||
if hasattr(self.model, "test_run"):
|
if hasattr(self.model, "test_run"):
|
||||||
|
if self.eval_loader is None:
|
||||||
|
self.eval_loader = self.get_eval_dataloader(
|
||||||
|
self.ap,
|
||||||
|
self.data_eval,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
||||||
samples = self.eval_loader.dataset.load_test_samples(1)
|
samples = self.eval_loader.dataset.load_test_samples(1)
|
||||||
figures, audios = self.model.test_run(self.ap, samples, None)
|
figures, audios = self.model.test_run(self.ap, samples, None)
|
||||||
else:
|
else:
|
||||||
figures, audios = self.model.test_run(self.ap)
|
figures, audios = self.model.test_run(self.ap)
|
||||||
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
||||||
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
|
self.dashboard_logger.test_figures(self.total_steps_done, figures)
|
||||||
|
|
||||||
def _fit(self) -> None:
|
def _fit(self) -> None:
|
||||||
"""🏃 train -> evaluate -> test for the number of epochs."""
|
"""🏃 train -> evaluate -> test for the number of epochs."""
|
||||||
if self.restore_step != 0 or self.args.best_path:
|
if self.restore_step != 0 or self.args.best_path:
|
||||||
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
|
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
|
||||||
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
|
self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
|
||||||
print(f" > Starting with loaded last best loss {self.best_loss}.")
|
print(f" > Starting with loaded last best loss {self.best_loss}.")
|
||||||
|
|
||||||
self.total_steps_done = self.restore_step
|
self.total_steps_done = self.restore_step
|
||||||
|
@ -802,10 +846,13 @@ class Trainer:
|
||||||
"""Where the ✨️magic✨️ happens..."""
|
"""Where the ✨️magic✨️ happens..."""
|
||||||
try:
|
try:
|
||||||
self._fit()
|
self._fit()
|
||||||
|
self.dashboard_logger.finish()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.callbacks.on_keyboard_interrupt()
|
self.callbacks.on_keyboard_interrupt()
|
||||||
# if the output folder is empty remove the run.
|
# if the output folder is empty remove the run.
|
||||||
remove_experiment_folder(self.output_path)
|
remove_experiment_folder(self.output_path)
|
||||||
|
# finish the wandb run and sync data
|
||||||
|
self.dashboard_logger.finish()
|
||||||
# stop without error signal
|
# stop without error signal
|
||||||
try:
|
try:
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -816,10 +863,33 @@ class Trainer:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
sys.exit(1)
|
sys.exit(1)
|

+   def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
+       """Pick the target loss to compare models"""
+       target_avg_loss = None
+
+       # return if target loss defined in the model config
+       if "target_loss" in self.config and self.config.target_loss:
+           return keep_avg_target[f"avg_{self.config.target_loss}"]
+
+       # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
+       if isinstance(self.optimizer, list):
+           target_avg_loss = 0
+           for idx in range(len(self.optimizer)):
+               target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
+           target_avg_loss /= len(self.optimizer)
+       else:
+           target_avg_loss = keep_avg_target["avg_loss"]
+       return target_avg_loss
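A compact illustration of the selection rule above, with a plain dict standing in for the KeepAverage tracker:

    def pick_target_avg_loss(keep_avg, target_loss=None, num_optimizers=1):
        if target_loss:
            return keep_avg[f"avg_{target_loss}"]
        if num_optimizers > 1:  # average the per-optimizer losses
            return sum(keep_avg[f"avg_loss_{i}"] for i in range(num_optimizers)) / num_optimizers
        return keep_avg["avg_loss"]

    stats = {"avg_loss_0": 0.8, "avg_loss_1": 1.2, "avg_loss": 1.0}
    print(pick_target_avg_loss(stats, num_optimizers=2))  # 1.0, the mean of the two optimizer losses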
||||||
|
|
||||||
def save_best_model(self) -> None:
|
def save_best_model(self) -> None:
|
||||||
"""Save the best model. It only saves if the current target loss is smaller then the previous."""
|
"""Save the best model. It only saves if the current target loss is smaller then the previous."""
|
||||||
|
|
||||||
|
# set the target loss to choose the best model
|
||||||
|
target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
|
||||||
|
|
||||||
|
# save the model and update the best_loss
|
||||||
self.best_loss = save_best_model(
|
self.best_loss = save_best_model(
|
||||||
self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
|
target_loss_dict,
|
||||||
self.best_loss,
|
self.best_loss,
|
||||||
self.config,
|
self.config,
|
||||||
self.model,
|
self.model,
|
||||||
|
@ -834,9 +904,16 @@ class Trainer:
|
||||||
|
|
    @staticmethod
    def _setup_logger_config(log_file: str) -> None:
-       logging.basicConfig(
-           level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
-       )
+       handlers = [logging.StreamHandler()]
+
+       # Only add a log file if the output location is local due to poor
+       # support for writing logs to file-like objects.
+       parsed_url = urlparse(log_file)
+       if not parsed_url.scheme or parsed_url.scheme == "file":
+           schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
+           handlers.append(logging.FileHandler(schemeless_path))
+
+       logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
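The scheme check is what prevents a FileHandler from being created for remote output locations. A quick illustration with hypothetical paths:

    from urllib.parse import urlparse

    for p in ["/tmp/run-1/trainer_0_log.txt", "s3://my-bucket/run-1/trainer_0_log.txt"]:
        scheme = urlparse(p).scheme
        print(p, "->", "local file handler" if not scheme or scheme == "file" else "console logging only")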
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_apex_available() -> bool:
|
def _is_apex_available() -> bool:
|
||||||
|
@ -920,28 +997,33 @@ class Trainer:
|
||||||
return criterion
|
return criterion
|
||||||
|
|
||||||
|
|
||||||
def init_arguments():
|
def getarguments():
|
||||||
train_config = TrainingArgs()
|
train_config = TrainingArgs()
|
||||||
parser = train_config.init_argparse(arg_prefix="")
|
parser = train_config.init_argparse(arg_prefix="")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_last_checkpoint(path):
|
def get_last_checkpoint(path: str) -> Tuple[str, str]:
|
||||||
"""Get latest checkpoint or/and best model in path.
|
"""Get latest checkpoint or/and best model in path.
|
||||||
|
|
||||||
It is based on globbing for `*.pth.tar` and the RegEx
|
It is based on globbing for `*.pth.tar` and the RegEx
|
||||||
`(checkpoint|best_model)_([0-9]+)`.
|
`(checkpoint|best_model)_([0-9]+)`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (list): Path to files to be compared.
|
path: Path to files to be compared.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If no checkpoint or best_model files are found.
|
ValueError: If no checkpoint or best_model files are found.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
last_checkpoint (str): Last checkpoint filename.
|
Path to the last checkpoint
|
||||||
|
Path to best checkpoint
|
||||||
"""
|
"""
|
||||||
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
|
fs = fsspec.get_mapper(path).fs
|
||||||
|
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
|
||||||
|
scheme = urlparse(path).scheme
|
||||||
|
if scheme: # scheme is not preserved in fs.glob, add it back
|
||||||
|
file_names = [scheme + "://" + file_name for file_name in file_names]
|
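The scheme is re-attached because fsspec filesystems generally return glob results without their protocol prefix. A small sketch of the same lookup against a hypothetical local folder (remote paths such as `s3://...` work the same way when the matching fsspec backend is installed):

    import os
    from urllib.parse import urlparse

    import fsspec

    path = "/tmp/checkpoints"  # placeholder run folder
    fs = fsspec.get_mapper(path).fs
    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
    scheme = urlparse(path).scheme
    if scheme:  # add the protocol back for remote paths
        file_names = [scheme + "://" + name for name in file_names]
    print(file_names)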
||||||
last_models = {}
|
last_models = {}
|
||||||
last_model_nums = {}
|
last_model_nums = {}
|
||||||
for key in ["checkpoint", "best_model"]:
|
for key in ["checkpoint", "best_model"]:
|
||||||
|
@ -963,7 +1045,7 @@ def get_last_checkpoint(path):
|
||||||
key_file_names = [fn for fn in file_names if key in fn]
|
key_file_names = [fn for fn in file_names if key in fn]
|
||||||
if last_model is None and len(key_file_names) > 0:
|
if last_model is None and len(key_file_names) > 0:
|
||||||
last_model = max(key_file_names, key=os.path.getctime)
|
last_model = max(key_file_names, key=os.path.getctime)
|
||||||
last_model_num = torch.load(last_model)["step"]
|
last_model_num = load_fsspec(last_model)["step"]
|
||||||
|
|
||||||
if last_model is not None:
|
if last_model is not None:
|
||||||
last_models[key] = last_model
|
last_models[key] = last_model
|
||||||
|
@ -997,8 +1079,8 @@ def process_args(args, config=None):
|
||||||
audio_path (str): Path to save generated test audios.
|
audio_path (str): Path to save generated test audios.
|
||||||
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
|
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
|
||||||
logging to the console.
|
logging to the console.
|
||||||
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
|
|
||||||
the TensorBoard logging.
|
dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
|
||||||
|
|
||||||
TODO:
|
TODO:
|
||||||
- Interactive config definition.
|
- Interactive config definition.
|
||||||
|
@ -1012,6 +1094,7 @@ def process_args(args, config=None):
|
||||||
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
||||||
if not args.best_path:
|
if not args.best_path:
|
||||||
args.best_path = best_model
|
args.best_path = best_model
|
||||||
|
|
||||||
# init config if not already defined
|
# init config if not already defined
|
||||||
if config is None:
|
if config is None:
|
||||||
if args.config_path:
|
if args.config_path:
|
||||||
|
@ -1030,12 +1113,12 @@ def process_args(args, config=None):
|
||||||
print(" > Mixed precision mode is ON")
|
print(" > Mixed precision mode is ON")
|
||||||
experiment_path = args.continue_path
|
experiment_path = args.continue_path
|
||||||
if not experiment_path:
|
if not experiment_path:
|
||||||
experiment_path = create_experiment_folder(config.output_path, config.run_name)
|
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
|
||||||
audio_path = os.path.join(experiment_path, "test_audios")
|
audio_path = os.path.join(experiment_path, "test_audios")
|
||||||
|
config.output_log_path = experiment_path
|
||||||
# setup rank 0 process in distributed training
|
# setup rank 0 process in distributed training
|
||||||
tb_logger = None
|
dashboard_logger = None
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
os.makedirs(audio_path, exist_ok=True)
|
|
||||||
new_fields = {}
|
new_fields = {}
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
new_fields["restore_path"] = args.restore_path
|
new_fields["restore_path"] = args.restore_path
|
||||||
|
@ -1043,17 +1126,20 @@ def process_args(args, config=None):
|
||||||
# if model characters are not set in the config file
|
# if model characters are not set in the config file
|
||||||
# save the default set to the config file for future
|
# save the default set to the config file for future
|
||||||
# compatibility.
|
# compatibility.
|
||||||
if config.has("characters_config"):
|
if config.has("characters") and config.characters is None:
|
||||||
used_characters = parse_symbols()
|
used_characters = parse_symbols()
|
||||||
new_fields["characters"] = used_characters
|
new_fields["characters"] = used_characters
|
||||||
copy_model_files(config, experiment_path, new_fields)
|
copy_model_files(config, experiment_path, new_fields)
|
||||||
os.chmod(audio_path, 0o775)
|
|
||||||
os.chmod(experiment_path, 0o775)
|
dashboard_logger = init_logger(config)
|
||||||
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
|
|
||||||
# write model desc to tensorboard
|
|
||||||
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
|
||||||
c_logger = ConsoleLogger()
|
c_logger = ConsoleLogger()
|
||||||
return config, experiment_path, audio_path, c_logger, tb_logger
|
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
||||||
|
|
||||||
|
|
||||||
|
def init_arguments():
|
||||||
|
train_config = TrainingArgs()
|
||||||
|
parser = train_config.init_argparse(arg_prefix="")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
|
def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
|
||||||
|
@ -1063,5 +1149,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
|
||||||
else:
|
else:
|
||||||
parser = init_arguments()
|
parser = init_arguments()
|
||||||
args = parser.parse_known_args()
|
args = parser.parse_known_args()
|
||||||
config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config)
|
config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
|
||||||
return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
|
return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
|
||||||
|
|
|
@ -13,12 +13,16 @@ class GSTConfig(Coqpit):
|
||||||
Args:
|
Args:
|
||||||
gst_style_input_wav (str):
|
gst_style_input_wav (str):
|
||||||
Path to the wav file used to define the style of the output speech at inference. Defaults to None.
|
Path to the wav file used to define the style of the output speech at inference. Defaults to None.
|
||||||
|
|
||||||
gst_style_input_weights (dict):
|
gst_style_input_weights (dict):
|
||||||
Defines the weights for each style token used at inference. Defaults to None.
|
Defines the weights for each style token used at inference. Defaults to None.
|
||||||
|
|
||||||
gst_embedding_dim (int):
|
gst_embedding_dim (int):
|
||||||
Defines the size of the GST embedding vector dimensions. Defaults to 256.
|
Defines the size of the GST embedding vector dimensions. Defaults to 256.
|
||||||
|
|
||||||
gst_num_heads (int):
|
gst_num_heads (int):
|
||||||
Number of attention heads used by the multi-head attention. Defaults to 4.
|
Number of attention heads used by the multi-head attention. Defaults to 4.
|
||||||
|
|
||||||
gst_num_style_tokens (int):
|
gst_num_style_tokens (int):
|
||||||
Number of style token vectors. Defaults to 10.
|
Number of style token vectors. Defaults to 10.
|
||||||
"""
|
"""
|
||||||
|
@ -51,17 +55,23 @@ class CharactersConfig(Coqpit):
|
||||||
Args:
|
Args:
|
||||||
pad (str):
|
pad (str):
|
||||||
characters in place of empty padding. Defaults to None.
|
characters in place of empty padding. Defaults to None.
|
||||||
|
|
||||||
eos (str):
|
eos (str):
|
||||||
characters showing the end of a sentence. Defaults to None.
|
characters showing the end of a sentence. Defaults to None.
|
||||||
|
|
||||||
bos (str):
|
bos (str):
|
||||||
characters showing the beginning of a sentence. Defaults to None.
|
characters showing the beginning of a sentence. Defaults to None.
|
||||||
|
|
||||||
characters (str):
|
characters (str):
|
||||||
character set used by the model. Characters not in this list are ignored when converting input text to
|
character set used by the model. Characters not in this list are ignored when converting input text to
|
||||||
a list of sequence IDs. Defaults to None.
|
a list of sequence IDs. Defaults to None.
|
||||||
|
|
||||||
punctuations (str):
|
punctuations (str):
|
||||||
characters considered as punctuation as parsing the input sentence. Defaults to None.
|
characters considered as punctuation as parsing the input sentence. Defaults to None.
|
||||||
|
|
||||||
phonemes (str):
|
phonemes (str):
|
||||||
characters considered as parsing phonemes. Defaults to None.
|
characters considered as parsing phonemes. Defaults to None.
|
||||||
|
|
||||||
unique (bool):
|
unique (bool):
|
||||||
remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old
|
remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old
|
||||||
models trained with character lists with duplicates.
|
models trained with character lists with duplicates.
|
||||||
|
@ -95,54 +105,78 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
Args:
|
Args:
|
||||||
audio (BaseAudioConfig):
|
audio (BaseAudioConfig):
|
||||||
Audio processor config object instance.
|
Audio processor config object instance.
|
||||||
|
|
||||||
use_phonemes (bool):
|
use_phonemes (bool):
|
||||||
enable / disable phoneme use.
|
enable / disable phoneme use.
|
||||||
|
|
||||||
use_espeak_phonemes (bool):
|
use_espeak_phonemes (bool):
|
||||||
enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`).
|
enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`).
|
||||||
|
|
||||||
compute_input_seq_cache (bool):
|
compute_input_seq_cache (bool):
|
||||||
enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of
|
enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of
|
||||||
the training, It allows faster data loader time and precise limitation with `max_seq_len` and
|
the training, It allows faster data loader time and precise limitation with `max_seq_len` and
|
||||||
`min_seq_len`.
|
`min_seq_len`.
|
||||||
|
|
||||||
text_cleaner (str):
|
text_cleaner (str):
|
||||||
Name of the text cleaner used for cleaning and formatting transcripts.
|
Name of the text cleaner used for cleaning and formatting transcripts.
|
||||||
|
|
||||||
enable_eos_bos_chars (bool):
|
enable_eos_bos_chars (bool):
|
||||||
enable / disable the use of eos and bos characters.
|
enable / disable the use of eos and bos characters.
|
||||||
|
|
||||||
test_senteces_file (str):
|
test_senteces_file (str):
|
||||||
Path to a txt file that has sentences used at test time. The file must have a sentence per line.
|
Path to a txt file that has sentences used at test time. The file must have a sentence per line.
|
||||||
|
|
||||||
phoneme_cache_path (str):
|
phoneme_cache_path (str):
|
||||||
Path to the output folder caching the computed phonemes for each sample.
|
Path to the output folder caching the computed phonemes for each sample.
|
||||||
|
|
||||||
characters (CharactersConfig):
|
characters (CharactersConfig):
|
||||||
Instance of a CharactersConfig class.
|
Instance of a CharactersConfig class.
|
||||||
|
|
||||||
batch_group_size (int):
|
batch_group_size (int):
|
||||||
Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
|
Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
|
||||||
length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
|
length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
|
||||||
prevent using the same batches for each epoch.
|
prevent using the same batches for each epoch.
|
||||||
|
|
||||||
loss_masking (bool):
|
loss_masking (bool):
|
||||||
enable / disable masking loss values against padded segments of samples in a batch.
|
enable / disable masking loss values against padded segments of samples in a batch.
|
||||||
|
|
||||||
min_seq_len (int):
|
min_seq_len (int):
|
||||||
Minimum input sequence length to be used at training.
|
Minimum input sequence length to be used at training.
|
||||||
|
|
||||||
max_seq_len (int):
|
max_seq_len (int):
|
||||||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||||
|
|
||||||
compute_f0 (int):
|
compute_f0 (int):
|
||||||
(Not in use yet).
|
(Not in use yet).
|
||||||
|
|
||||||
|
compute_linear_spec (bool):
|
||||||
|
If True data loader computes and returns linear spectrograms alongside the other data.
|
||||||
|
|
||||||
use_noise_augment (bool):
|
use_noise_augment (bool):
|
||||||
Augment the input audio with random noise.
|
Augment the input audio with random noise.
|
||||||
|
|
||||||
add_blank (bool):
|
add_blank (bool):
|
||||||
Add blank characters between each other two characters. It improves performance for some models at expense
|
Add blank characters between each other two characters. It improves performance for some models at expense
|
||||||
of slower run-time due to the longer input sequence.
|
of slower run-time due to the longer input sequence.
|
||||||
|
|
||||||
datasets (List[BaseDatasetConfig]):
|
datasets (List[BaseDatasetConfig]):
|
||||||
List of datasets used for training. If multiple datasets are provided, they are merged and used together
|
List of datasets used for training. If multiple datasets are provided, they are merged and used together
|
||||||
for training.
|
for training.
|
||||||
|
|
||||||
optimizer (str):
|
optimizer (str):
|
||||||
Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
|
Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
|
||||||
Defaults to ``.
|
Defaults to ``.
|
||||||
|
|
||||||
optimizer_params (dict):
|
optimizer_params (dict):
|
||||||
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
|
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
|
||||||
|
|
||||||
lr_scheduler (str):
|
lr_scheduler (str):
|
||||||
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
|
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
|
||||||
`TTS.utils.training`. Defaults to ``.
|
`TTS.utils.training`. Defaults to ``.
|
||||||
|
|
||||||
lr_scheduler_params (dict):
|
lr_scheduler_params (dict):
|
||||||
Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
|
Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
|
||||||
|
|
||||||
test_sentences (List[str]):
|
test_sentences (List[str]):
|
||||||
List of sentences to be used at testing. Defaults to '[]'
|
List of sentences to be used at testing. Defaults to '[]'
|
||||||
"""
|
"""
|
||||||
|
@ -166,6 +200,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
min_seq_len: int = 1
|
min_seq_len: int = 1
|
||||||
max_seq_len: int = float("inf")
|
max_seq_len: int = float("inf")
|
||||||
compute_f0: bool = False
|
compute_f0: bool = False
|
||||||
|
compute_linear_spec: bool = False
|
||||||
use_noise_augment: bool = False
|
use_noise_augment: bool = False
|
||||||
add_blank: bool = False
|
add_blank: bool = False
|
||||||
# dataset
|
# dataset
|
||||||
|
|
|
@@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs


@dataclass
class VitsConfig(BaseTTSConfig):
    """Defines parameters for VITS End2End TTS model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (VitsArgs):
            Model architecture arguments. Defaults to `VitsArgs()`.

        grad_clip (List):
            Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`.

        lr_gen (float):
            Initial learning rate for the generator. Defaults to 0.0002.

        lr_disc (float):
            Initial learning rate for the discriminator. Defaults to 0.0002.

        lr_scheduler_gen (str):
            Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_gen_params (dict):
            Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        lr_scheduler_disc (str):
            Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_disc_params (dict):
            Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        scheduler_after_epoch (bool):
            If true, step the schedulers after each epoch else after each step. Defaults to `False`.

        optimizer (str):
            Name of the optimizer to use with both the generator and the discriminator networks. One of the
            `torch.optim.*`. Defaults to `AdamW`.

        kl_loss_alpha (float):
            Loss weight for KL loss. Defaults to 1.0.

        disc_loss_alpha (float):
            Loss weight for the discriminator loss. Defaults to 1.0.

        gen_loss_alpha (float):
            Loss weight for the generator loss. Defaults to 1.0.

        feat_loss_alpha (float):
            Loss weight for the feature matching loss. Defaults to 1.0.

        mel_loss_alpha (float):
            Loss weight for the mel loss. Defaults to 45.0.

        return_wav (bool):
            If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

        compute_linear_spec (bool):
            If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

        min_seq_len (int):
            Minimum text length to be considered for training. Defaults to `13`.

        max_seq_len (int):
            Maximum text length to be considered for training. Defaults to `500`.

        r (int):
            Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

        add_blank (bool):
            If true, a blank token is added in between every character. Defaults to `True`.

        test_sentences (List[str]):
            List of sentences to be used for testing.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

    Example:

        >>> from TTS.tts.configs import VitsConfig
        >>> config = VitsConfig()
    """

    model: str = "vits"
    # model specific params
    model_args: VitsArgs = field(default_factory=VitsArgs)

    # optimizer
    grad_clip: List[float] = field(default_factory=lambda: [5, 5])
    lr_gen: float = 0.0002
    lr_disc: float = 0.0002
    lr_scheduler_gen: str = "ExponentialLR"
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # loss params
    kl_loss_alpha: float = 1.0
    disc_loss_alpha: float = 1.0
    gen_loss_alpha: float = 1.0
    feat_loss_alpha: float = 1.0
    mel_loss_alpha: float = 45.0

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # overrides
    min_seq_len: int = 13
    max_seq_len: int = 500
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # testing
    test_sentences: List[str] = field(
        default_factory=lambda: [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist.",
            "Prior to November 22, 1963.",
        ]
    )
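Following the `Example` section of the docstring, a typical use is to instantiate the config and override a few fields; a hedged sketch using only fields defined in this file:

    from TTS.tts.configs import VitsConfig

    config = VitsConfig(
        lr_gen=2e-4,
        lr_disc=2e-4,
        mel_loss_alpha=45.0,
        min_seq_len=13,
        max_seq_len=500,
    )
    print(config.model, config.grad_clip, config.scheduler_after_epoch)
    # expected: vits [5, 5] True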
|
|
@ -23,7 +23,9 @@ class TTSDataset(Dataset):
|
||||||
ap: AudioProcessor,
|
ap: AudioProcessor,
|
||||||
meta_data: List[List],
|
meta_data: List[List],
|
||||||
characters: Dict = None,
|
characters: Dict = None,
|
||||||
|
custom_symbols: List = None,
|
||||||
add_blank: bool = False,
|
add_blank: bool = False,
|
||||||
|
return_wav: bool = False,
|
||||||
batch_group_size: int = 0,
|
batch_group_size: int = 0,
|
||||||
min_seq_len: int = 0,
|
min_seq_len: int = 0,
|
||||||
max_seq_len: int = float("inf"),
|
max_seq_len: int = float("inf"),
|
||||||
|
@ -54,9 +56,14 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
||||||
|
|
||||||
|
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
|
||||||
|
set of symbols need to pass it here. Defaults to `None`.
|
||||||
|
|
||||||
add_blank (bool): Add a special `blank` character after every other character. It helps some
|
add_blank (bool): Add a special `blank` character after every other character. It helps some
|
||||||
models achieve better results. Defaults to false.
|
models achieve better results. Defaults to false.
|
||||||
|
|
||||||
|
return_wav (bool): Return the waveform of the sample. Defaults to False.
|
||||||
|
|
||||||
batch_group_size (int): Range of batch randomization after sorting
|
batch_group_size (int): Range of batch randomization after sorting
|
||||||
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
|
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
|
||||||
batch. Set 0 to disable. Defaults to 0.
|
batch. Set 0 to disable. Defaults to 0.
|
||||||
|
@ -95,10 +102,12 @@ class TTSDataset(Dataset):
|
||||||
self.sample_rate = ap.sample_rate
|
self.sample_rate = ap.sample_rate
|
||||||
self.cleaners = text_cleaner
|
self.cleaners = text_cleaner
|
||||||
self.compute_linear_spec = compute_linear_spec
|
self.compute_linear_spec = compute_linear_spec
|
||||||
|
self.return_wav = return_wav
|
||||||
self.min_seq_len = min_seq_len
|
self.min_seq_len = min_seq_len
|
||||||
self.max_seq_len = max_seq_len
|
self.max_seq_len = max_seq_len
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.characters = characters
|
self.characters = characters
|
||||||
|
self.custom_symbols = custom_symbols
|
||||||
self.add_blank = add_blank
|
self.add_blank = add_blank
|
||||||
self.use_phonemes = use_phonemes
|
self.use_phonemes = use_phonemes
|
||||||
self.phoneme_cache_path = phoneme_cache_path
|
self.phoneme_cache_path = phoneme_cache_path
|
||||||
|
@ -109,6 +118,7 @@ class TTSDataset(Dataset):
|
||||||
self.use_noise_augment = use_noise_augment
|
self.use_noise_augment = use_noise_augment
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.input_seq_computed = False
|
self.input_seq_computed = False
|
||||||
|
self.rescue_item_idx = 1
|
||||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -128,13 +138,21 @@ class TTSDataset(Dataset):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
|
def _generate_and_cache_phoneme_sequence(
|
||||||
|
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||||
|
):
|
||||||
"""generate a phoneme sequence from text.
|
"""generate a phoneme sequence from text.
|
||||||
since the usage is for subsequent caching, we never add bos and
|
since the usage is for subsequent caching, we never add bos and
|
||||||
eos chars here. Instead we add those dynamically later; based on the
|
eos chars here. Instead we add those dynamically later; based on the
|
||||||
config option."""
|
config option."""
|
||||||
phonemes = phoneme_to_sequence(
|
phonemes = phoneme_to_sequence(
|
||||||
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
|
text,
|
||||||
|
[cleaners],
|
||||||
|
language=language,
|
||||||
|
enable_eos_bos=False,
|
||||||
|
custom_symbols=custom_symbols,
|
||||||
|
tp=characters,
|
||||||
|
add_blank=add_blank,
|
||||||
)
|
)
|
||||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||||
np.save(cache_path, phonemes)
|
np.save(cache_path, phonemes)
|
||||||
|
@ -142,7 +160,7 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_or_generate_phoneme_sequence(
|
def _load_or_generate_phoneme_sequence(
|
||||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
|
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
|
||||||
):
|
):
|
||||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||||
|
|
||||||
|
@ -153,12 +171,12 @@ class TTSDataset(Dataset):
|
||||||
phonemes = np.load(cache_path)
|
phonemes = np.load(cache_path)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||||
text, cache_path, cleaners, language, characters, add_blank
|
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||||
)
|
)
|
||||||
except (ValueError, IOError):
|
except (ValueError, IOError):
|
||||||
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
||||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||||
text, cache_path, cleaners, language, characters, add_blank
|
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||||
)
|
)
|
||||||
if enable_eos_bos:
|
if enable_eos_bos:
|
||||||
phonemes = pad_with_eos_bos(phonemes, tp=characters)
|
phonemes = pad_with_eos_bos(phonemes, tp=characters)
|
||||||
|
@ -173,6 +191,7 @@ class TTSDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
text, wav_file, speaker_name = item
|
text, wav_file, speaker_name = item
|
||||||
attn = None
|
attn = None
|
||||||
|
raw_text = text
|
||||||
|
|
||||||
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
|
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
|
||||||
|
|
||||||
|
@ -189,13 +208,19 @@ class TTSDataset(Dataset):
|
||||||
self.enable_eos_bos,
|
self.enable_eos_bos,
|
||||||
self.cleaners,
|
self.cleaners,
|
||||||
self.phoneme_language,
|
self.phoneme_language,
|
||||||
|
self.custom_symbols,
|
||||||
self.characters,
|
self.characters,
|
||||||
self.add_blank,
|
self.add_blank,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
text = np.asarray(
|
text = np.asarray(
|
||||||
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
text_to_sequence(
|
||||||
|
text,
|
||||||
|
[self.cleaners],
|
||||||
|
custom_symbols=self.custom_symbols,
|
||||||
|
tp=self.characters,
|
||||||
|
add_blank=self.add_blank,
|
||||||
|
),
|
||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -209,9 +234,10 @@ class TTSDataset(Dataset):
|
||||||
# return a different sample if the phonemized
|
# return a different sample if the phonemized
|
||||||
# text is longer than the threshold
|
# text is longer than the threshold
|
||||||
# TODO: find a better fix
|
# TODO: find a better fix
|
||||||
return self.load_data(100)
|
return self.load_data(self.rescue_item_idx)
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
|
"raw_text": raw_text,
|
||||||
"text": text,
|
"text": text,
|
||||||
"wav": wav,
|
"wav": wav,
|
||||||
"attn": attn,
|
"attn": attn,
|
||||||
|
@ -238,7 +264,13 @@ class TTSDataset(Dataset):
|
||||||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
||||||
text, *_ = item
|
text, *_ = item
|
||||||
sequence = np.asarray(
|
sequence = np.asarray(
|
||||||
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
text_to_sequence(
|
||||||
|
text,
|
||||||
|
[self.cleaners],
|
||||||
|
custom_symbols=self.custom_symbols,
|
||||||
|
tp=self.characters,
|
||||||
|
add_blank=self.add_blank,
|
||||||
|
),
|
||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
self.items[idx][0] = sequence
|
self.items[idx][0] = sequence
|
||||||
|
@ -249,6 +281,7 @@ class TTSDataset(Dataset):
|
||||||
self.enable_eos_bos,
|
self.enable_eos_bos,
|
||||||
self.cleaners,
|
self.cleaners,
|
||||||
self.phoneme_language,
|
self.phoneme_language,
|
||||||
|
self.custom_symbols,
|
||||||
self.characters,
|
self.characters,
|
||||||
self.add_blank,
|
self.add_blank,
|
||||||
]
|
]
|
||||||
|
@ -329,6 +362,7 @@ class TTSDataset(Dataset):
|
||||||
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
|
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
|
||||||
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
|
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
|
||||||
text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
|
text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
|
||||||
|
raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
|
||||||
|
|
||||||
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
|
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
|
||||||
# get pre-computed d-vectors
|
# get pre-computed d-vectors
|
||||||
|
@ -347,6 +381,14 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
mel_lengths = [m.shape[1] for m in mel]
|
mel_lengths = [m.shape[1] for m in mel]
|
||||||
|
|
||||||
|
# lengths adjusted by the reduction factor
|
||||||
|
mel_lengths_adjusted = [
|
||||||
|
m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
|
||||||
|
if m.shape[1] % self.outputs_per_step
|
||||||
|
else m.shape[1]
|
||||||
|
for m in mel
|
||||||
|
]
|
||||||
|
|
||||||
# compute 'stop token' targets
|
# compute 'stop token' targets
|
||||||
stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
|
stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
|
||||||
|
|
||||||
|
@ -385,6 +427,20 @@ class TTSDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
linear = None
|
linear = None
|
||||||
|
|
||||||
|
# format waveforms
|
||||||
|
wav_padded = None
|
||||||
|
if self.return_wav:
|
||||||
|
wav_lengths = [w.shape[0] for w in wav]
|
||||||
|
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
|
||||||
|
wav_lengths = torch.LongTensor(wav_lengths)
|
||||||
|
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
|
||||||
|
for i, w in enumerate(wav):
|
||||||
|
mel_length = mel_lengths_adjusted[i]
|
||||||
|
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
|
||||||
|
w = w[: mel_length * self.ap.hop_length]
|
||||||
|
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
|
||||||
|
wav_padded.transpose_(1, 2)
|
||||||
|
|
||||||
# collate attention alignments
|
# collate attention alignments
|
||||||
if batch[0]["attn"] is not None:
|
if batch[0]["attn"] is not None:
|
||||||
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
|
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
|
||||||
|
@ -397,6 +453,7 @@ class TTSDataset(Dataset):
|
||||||
attns = torch.FloatTensor(attns).unsqueeze(1)
|
attns = torch.FloatTensor(attns).unsqueeze(1)
|
||||||
else:
|
else:
|
||||||
attns = None
|
attns = None
|
||||||
|
# TODO: return dictionary
|
||||||
return (
|
return (
|
||||||
text,
|
text,
|
||||||
text_lenghts,
|
text_lenghts,
|
||||||
|
@ -409,6 +466,8 @@ class TTSDataset(Dataset):
|
||||||
d_vectors,
|
d_vectors,
|
||||||
speaker_ids,
|
speaker_ids,
|
||||||
attns,
|
attns,
|
||||||
|
wav_padded,
|
||||||
|
raw_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
|
|
@ -28,6 +28,31 @@ class LayerNorm(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm2(nn.Module):
|
||||||
|
"""Layer norm for the 2nd dimension of the input using torch primitive.
|
||||||
|
Args:
|
||||||
|
channels (int): number of channels (2nd dimension) of the input.
|
||||||
|
eps (float): to prevent 0 division
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- input: (B, C, T)
|
||||||
|
- output: (B, C, T)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.gamma = nn.Parameter(torch.ones(channels))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(channels))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.transpose(1, -1)
|
||||||
|
x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||||
|
return x.transpose(1, -1)
|
||||||
|
|
||||||
|
|
||||||
class TemporalBatchNorm1d(nn.BatchNorm1d):
|
class TemporalBatchNorm1d(nn.BatchNorm1d):
|
||||||
"""Normalize each channel separately over time and batch."""
|
"""Normalize each channel separately over time and batch."""
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
|
||||||
dropout_p (float): Dropout rate used after each conv layer.
|
dropout_p (float): Dropout rate used after each conv layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
|
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# class arguments
|
# class arguments
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
@ -33,13 +33,18 @@ class DurationPredictor(nn.Module):
|
||||||
self.norm_2 = LayerNorm(hidden_channels)
|
self.norm_2 = LayerNorm(hidden_channels)
|
||||||
# output layer
|
# output layer
|
||||||
self.proj = nn.Conv1d(hidden_channels, 1, 1)
|
self.proj = nn.Conv1d(hidden_channels, 1, 1)
|
||||||
|
if cond_channels is not None and cond_channels != 0:
|
||||||
|
self.cond = nn.Conv1d(cond_channels, in_channels, 1)
|
||||||
|
|
||||||
def forward(self, x, x_mask):
|
def forward(self, x, x_mask, g=None):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
- x: :math:`[B, C, T]`
|
- x: :math:`[B, C, T]`
|
||||||
- x_mask: :math:`[B, 1, T]`
|
- x_mask: :math:`[B, 1, T]`
|
||||||
|
- g: :math:`[B, C, 1]`
|
||||||
"""
|
"""
|
||||||
|
if g is not None:
|
||||||
|
x = x + self.cond(g)
|
||||||
x = self.conv_1(x * x_mask)
|
x = self.conv_1(x * x_mask)
|
||||||
x = torch.relu(x)
|
x = torch.relu(x)
|
||||||
x = self.norm_1(x)
|
x = self.norm_1(x)
|
||||||
|
|
|
@ -16,7 +16,7 @@ class ResidualConv1dLayerNormBlock(nn.Module):
|
||||||
::
|
::
|
||||||
|
|
||||||
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|
||||||
|---------------> conv1d_1x1 -----------------------|
|
|---------------> conv1d_1x1 ------------------|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): number of input tensor channels.
|
in_channels (int): number of input tensor channels.
|
||||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from TTS.tts.layers.glow_tts.glow import LayerNorm
|
from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
|
||||||
|
|
||||||
|
|
||||||
class RelativePositionMultiHeadAttention(nn.Module):
|
class RelativePositionMultiHeadAttention(nn.Module):
|
||||||
|
@ -271,7 +271,7 @@ class FeedForwardNetwork(nn.Module):
|
||||||
dropout_p (float, optional): dropout rate. Defaults to 0.
|
dropout_p (float, optional): dropout rate. Defaults to 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0):
|
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
@ -280,17 +280,46 @@ class FeedForwardNetwork(nn.Module):
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.dropout_p = dropout_p
|
self.dropout_p = dropout_p
|
||||||
|
|
||||||
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
|
if causal:
|
||||||
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
self.padding = self._causal_padding
|
||||||
|
else:
|
||||||
|
self.padding = self._same_padding
|
||||||
|
|
||||||
|
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size)
|
||||||
|
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size)
|
||||||
self.dropout = nn.Dropout(dropout_p)
|
self.dropout = nn.Dropout(dropout_p)
|
||||||
|
|
||||||
def forward(self, x, x_mask):
|
def forward(self, x, x_mask):
|
||||||
x = self.conv_1(x * x_mask)
|
x = self.conv_1(self.padding(x * x_mask))
|
||||||
x = torch.relu(x)
|
x = torch.relu(x)
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
x = self.conv_2(x * x_mask)
|
x = self.conv_2(self.padding(x * x_mask))
|
||||||
return x * x_mask
|
return x * x_mask
|
||||||
|
|
||||||
|
def _causal_padding(self, x):
|
||||||
|
if self.kernel_size == 1:
|
||||||
|
return x
|
||||||
|
pad_l = self.kernel_size - 1
|
||||||
|
pad_r = 0
|
||||||
|
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||||
|
x = F.pad(x, self._pad_shape(padding))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _same_padding(self, x):
|
||||||
|
if self.kernel_size == 1:
|
||||||
|
return x
|
||||||
|
pad_l = (self.kernel_size - 1) // 2
|
||||||
|
pad_r = self.kernel_size // 2
|
||||||
|
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||||
|
x = F.pad(x, self._pad_shape(padding))
|
||||||
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pad_shape(padding):
|
||||||
|
l = padding[::-1]
|
||||||
|
pad_shape = [item for sublist in l for item in sublist]
|
||||||
|
return pad_shape
|
||||||
|
|
||||||
|
|
||||||
class RelativePositionTransformer(nn.Module):
|
class RelativePositionTransformer(nn.Module):
|
||||||
"""Transformer with Relative Potional Encoding.
|
"""Transformer with Relative Potional Encoding.
|
||||||
|
@ -310,20 +339,23 @@ class RelativePositionTransformer(nn.Module):
|
||||||
If default, relative encoding is disabled and it is a regular transformer.
|
If default, relative encoding is disabled and it is a regular transformer.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
input_length (int, optional): input lenght to limit position encoding. Defaults to None.
|
input_length (int, optional): input lenght to limit position encoding. Defaults to None.
|
||||||
|
layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm
|
||||||
|
primitive. Use type "2", type "1: is for backward compat. Defaults to "1".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels,
|
in_channels: int,
|
||||||
out_channels,
|
out_channels: int,
|
||||||
hidden_channels,
|
hidden_channels: int,
|
||||||
hidden_channels_ffn,
|
hidden_channels_ffn: int,
|
||||||
num_heads,
|
num_heads: int,
|
||||||
num_layers,
|
num_layers: int,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
dropout_p=0.0,
|
dropout_p=0.0,
|
||||||
rel_attn_window_size=None,
|
rel_attn_window_size: int = None,
|
||||||
input_length=None,
|
input_length: int = None,
|
||||||
|
layer_norm_type: str = "1",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_channels = hidden_channels
|
self.hidden_channels = hidden_channels
|
||||||
|
@ -351,7 +383,12 @@ class RelativePositionTransformer(nn.Module):
|
||||||
input_length=input_length,
|
input_length=input_length,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
if layer_norm_type == "1":
|
||||||
|
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||||
|
elif layer_norm_type == "2":
|
||||||
|
self.norm_layers_1.append(LayerNorm2(hidden_channels))
|
||||||
|
else:
|
||||||
|
raise ValueError(" [!] Unknown layer norm type")
|
||||||
|
|
||||||
if hidden_channels != out_channels and (idx + 1) == self.num_layers:
|
if hidden_channels != out_channels and (idx + 1) == self.num_layers:
|
||||||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||||
|
@ -366,7 +403,12 @@ class RelativePositionTransformer(nn.Module):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
|
if layer_norm_type == "1":
|
||||||
|
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
|
||||||
|
elif layer_norm_type == "2":
|
||||||
|
self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
|
||||||
|
else:
|
||||||
|
raise ValueError(" [!] Unknown layer norm type")
|
||||||
|
|
||||||
def forward(self, x, x_mask):
|
def forward(self, x, x_mask):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -2,11 +2,13 @@ import math
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional
|
from torch.nn import functional
|
||||||
|
|
||||||
from TTS.tts.utils.data import sequence_mask
|
from TTS.tts.utils.data import sequence_mask
|
||||||
from TTS.tts.utils.ssim import ssim
|
from TTS.tts.utils.ssim import ssim
|
||||||
|
from TTS.utils.audio import TorchSTFT
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=abstract-method
|
# pylint: disable=abstract-method
|
||||||
|
@ -514,3 +516,142 @@ class AlignTTSLoss(nn.Module):
|
||||||
+ self.mdn_alpha * mdn_loss
|
+ self.mdn_alpha * mdn_loss
|
||||||
)
|
)
|
||||||
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
|
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
|
||||||
|
|
||||||
|
|
||||||
|
class VitsGeneratorLoss(nn.Module):
|
||||||
|
def __init__(self, c: Coqpit):
|
||||||
|
super().__init__()
|
||||||
|
self.kl_loss_alpha = c.kl_loss_alpha
|
||||||
|
self.gen_loss_alpha = c.gen_loss_alpha
|
||||||
|
self.feat_loss_alpha = c.feat_loss_alpha
|
||||||
|
self.mel_loss_alpha = c.mel_loss_alpha
|
||||||
|
self.stft = TorchSTFT(
|
||||||
|
c.audio.fft_size,
|
||||||
|
c.audio.hop_length,
|
||||||
|
c.audio.win_length,
|
||||||
|
sample_rate=c.audio.sample_rate,
|
||||||
|
mel_fmin=c.audio.mel_fmin,
|
||||||
|
mel_fmax=c.audio.mel_fmax,
|
||||||
|
n_mels=c.audio.num_mels,
|
||||||
|
use_mel=True,
|
||||||
|
do_amp_to_db=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def feature_loss(feats_real, feats_generated):
|
||||||
|
loss = 0
|
||||||
|
for dr, dg in zip(feats_real, feats_generated):
|
||||||
|
for rl, gl in zip(dr, dg):
|
||||||
|
rl = rl.float().detach()
|
||||||
|
gl = gl.float()
|
||||||
|
loss += torch.mean(torch.abs(rl - gl))
|
||||||
|
|
||||||
|
return loss * 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generator_loss(scores_fake):
|
||||||
|
loss = 0
|
||||||
|
gen_losses = []
|
||||||
|
for dg in scores_fake:
|
||||||
|
dg = dg.float()
|
||||||
|
l = torch.mean((1 - dg) ** 2)
|
||||||
|
gen_losses.append(l)
|
||||||
|
loss += l
|
||||||
|
|
||||||
|
return loss, gen_losses
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
||||||
|
"""
|
||||||
|
z_p, logs_q: [b, h, t_t]
|
||||||
|
m_p, logs_p: [b, h, t_t]
|
||||||
|
"""
|
||||||
|
z_p = z_p.float()
|
||||||
|
logs_q = logs_q.float()
|
||||||
|
m_p = m_p.float()
|
||||||
|
logs_p = logs_p.float()
|
||||||
|
z_mask = z_mask.float()
|
||||||
|
|
||||||
|
kl = logs_p - logs_q - 0.5
|
||||||
|
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
||||||
|
kl = torch.sum(kl * z_mask)
|
||||||
|
l = kl / torch.sum(z_mask)
|
||||||
|
return l
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
waveform,
|
||||||
|
waveform_hat,
|
||||||
|
z_p,
|
||||||
|
logs_q,
|
||||||
|
m_p,
|
||||||
|
logs_p,
|
||||||
|
z_len,
|
||||||
|
scores_disc_fake,
|
||||||
|
feats_disc_fake,
|
||||||
|
feats_disc_real,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- wavefrom: :math:`[B, 1, T]`
|
||||||
|
- waveform_hat: :math:`[B, 1, T]`
|
||||||
|
- z_p: :math:`[B, C, T]`
|
||||||
|
- logs_q: :math:`[B, C, T]`
|
||||||
|
- m_p: :math:`[B, C, T]`
|
||||||
|
- logs_p: :math:`[B, C, T]`
|
||||||
|
- z_len: :math:`[B]`
|
||||||
|
- scores_disc_fake[i]: :math:`[B, C]`
|
||||||
|
- feats_disc_fake[i][j]: :math:`[B, C, T', P]`
|
||||||
|
- feats_disc_real[i][j]: :math:`[B, C, T', P]`
|
||||||
|
"""
|
||||||
|
loss = 0.0
|
||||||
|
return_dict = {}
|
||||||
|
z_mask = sequence_mask(z_len).float()
|
||||||
|
# compute mel spectrograms from the waveforms
|
||||||
|
mel = self.stft(waveform)
|
||||||
|
mel_hat = self.stft(waveform_hat)
|
||||||
|
# compute losses
|
||||||
|
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
|
||||||
|
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
||||||
|
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||||
|
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
||||||
|
loss = loss_kl + loss_feat + loss_mel + loss_gen
|
||||||
|
# pass losses to the dict
|
||||||
|
return_dict["loss_gen"] = loss_gen
|
||||||
|
return_dict["loss_kl"] = loss_kl
|
||||||
|
return_dict["loss_feat"] = loss_feat
|
||||||
|
return_dict["loss_mel"] = loss_mel
|
||||||
|
return_dict["loss"] = loss
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
|
||||||
|
class VitsDiscriminatorLoss(nn.Module):
|
||||||
|
def __init__(self, c: Coqpit):
|
||||||
|
super().__init__()
|
||||||
|
self.disc_loss_alpha = c.disc_loss_alpha
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def discriminator_loss(scores_real, scores_fake):
|
||||||
|
loss = 0
|
||||||
|
real_losses = []
|
||||||
|
fake_losses = []
|
||||||
|
for dr, dg in zip(scores_real, scores_fake):
|
||||||
|
dr = dr.float()
|
||||||
|
dg = dg.float()
|
||||||
|
real_loss = torch.mean((1 - dr) ** 2)
|
||||||
|
fake_loss = torch.mean(dg ** 2)
|
||||||
|
loss += real_loss + fake_loss
|
||||||
|
real_losses.append(real_loss.item())
|
||||||
|
fake_losses.append(fake_loss.item())
|
||||||
|
|
||||||
|
return loss, real_losses, fake_losses
|
||||||
|
|
||||||
|
def forward(self, scores_disc_real, scores_disc_fake):
|
||||||
|
loss = 0.0
|
||||||
|
return_dict = {}
|
||||||
|
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
|
||||||
|
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
|
||||||
|
loss = loss + loss_disc
|
||||||
|
return_dict["loss_disc"] = loss_disc
|
||||||
|
return_dict["loss"] = loss
|
||||||
|
return return_dict
|
||||||
|
|
|
@ -8,10 +8,10 @@ class GST(nn.Module):
|
||||||
|
|
||||||
See https://arxiv.org/pdf/1803.09017"""
|
See https://arxiv.org/pdf/1803.09017"""
|
||||||
|
|
||||||
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
|
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
|
self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
|
||||||
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim)
|
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim)
|
||||||
|
|
||||||
def forward(self, inputs, speaker_embedding=None):
|
def forward(self, inputs, speaker_embedding=None):
|
||||||
enc_out = self.encoder(inputs)
|
enc_out = self.encoder(inputs)
|
||||||
|
@ -83,19 +83,19 @@ class ReferenceEncoder(nn.Module):
|
||||||
class StyleTokenLayer(nn.Module):
|
class StyleTokenLayer(nn.Module):
|
||||||
"""NN Module attending to style tokens based on prosody encodings."""
|
"""NN Module attending to style tokens based on prosody encodings."""
|
||||||
|
|
||||||
def __init__(self, num_heads, num_style_tokens, embedding_dim, d_vector_dim=None):
|
def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.query_dim = embedding_dim // 2
|
self.query_dim = gst_embedding_dim // 2
|
||||||
|
|
||||||
if d_vector_dim:
|
if d_vector_dim:
|
||||||
self.query_dim += d_vector_dim
|
self.query_dim += d_vector_dim
|
||||||
|
|
||||||
self.key_dim = embedding_dim // num_heads
|
self.key_dim = gst_embedding_dim // num_heads
|
||||||
self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
|
self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
|
||||||
nn.init.normal_(self.style_tokens, mean=0, std=0.5)
|
nn.init.normal_(self.style_tokens, mean=0, std=0.5)
|
||||||
self.attention = MultiHeadAttention(
|
self.attention = MultiHeadAttention(
|
||||||
query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads
|
query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.modules.conv import Conv1d
|
||||||
|
|
||||||
|
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorS(torch.nn.Module):
|
||||||
|
"""HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, use_spectral_norm=False):
|
||||||
|
super().__init__()
|
||||||
|
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||||
|
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||||
|
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||||
|
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||||
|
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||||
|
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (Tensor): input waveform.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: discriminator scores.
|
||||||
|
List[Tensor]: list of features from the convolutiona layers.
|
||||||
|
"""
|
||||||
|
feat = []
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = torch.nn.functional.leaky_relu(x, 0.1)
|
||||||
|
feat.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
feat.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
return x, feat
|
||||||
|
|
||||||
|
|
||||||
|
class VitsDiscriminator(nn.Module):
|
||||||
|
"""VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator.
|
||||||
|
|
||||||
|
::
|
||||||
|
waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
|
||||||
|
|--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, use_spectral_norm=False):
|
||||||
|
super().__init__()
|
||||||
|
self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
|
||||||
|
self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (Tensor): input waveform.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tensor]: discriminator scores.
|
||||||
|
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
||||||
|
"""
|
||||||
|
scores, feats = self.mpd(x)
|
||||||
|
score_sd, feats_sd = self.sd(x)
|
||||||
|
return scores + [score_sd], feats + [feats_sd]
|
|
@ -0,0 +1,271 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from TTS.tts.layers.glow_tts.glow import WN
|
||||||
|
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
||||||
|
from TTS.tts.utils.data import sequence_mask
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pad_shape(pad_shape):
|
||||||
|
l = pad_shape[::-1]
|
||||||
|
pad_shape = [item for sublist in l for item in sublist]
|
||||||
|
return pad_shape
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m, mean=0.0, std=0.01):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
m.weight.data.normal_(mean, std)
|
||||||
|
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_vocab: int,
|
||||||
|
out_channels: int,
|
||||||
|
hidden_channels: int,
|
||||||
|
hidden_channels_ffn: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_layers: int,
|
||||||
|
kernel_size: int,
|
||||||
|
dropout_p: float,
|
||||||
|
):
|
||||||
|
"""Text Encoder for VITS model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_vocab (int): Number of characters for the embedding layer.
|
||||||
|
out_channels (int): Number of channels for the output.
|
||||||
|
hidden_channels (int): Number of channels for the hidden layers.
|
||||||
|
hidden_channels_ffn (int): Number of channels for the convolutional layers.
|
||||||
|
num_heads (int): Number of attention heads for the Transformer layers.
|
||||||
|
num_layers (int): Number of Transformer layers.
|
||||||
|
kernel_size (int): Kernel size for the FFN layers in Transformer network.
|
||||||
|
dropout_p (float): Dropout rate for the Transformer layers.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
|
||||||
|
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
||||||
|
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
|
||||||
|
|
||||||
|
self.encoder = RelativePositionTransformer(
|
||||||
|
in_channels=hidden_channels,
|
||||||
|
out_channels=hidden_channels,
|
||||||
|
hidden_channels=hidden_channels,
|
||||||
|
hidden_channels_ffn=hidden_channels_ffn,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_layers=num_layers,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
dropout_p=dropout_p,
|
||||||
|
layer_norm_type="2",
|
||||||
|
rel_attn_window_size=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
|
def forward(self, x, x_lengths):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, T]`
|
||||||
|
- x_length: :math:`[B]`
|
||||||
|
"""
|
||||||
|
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
||||||
|
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||||||
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||||
|
|
||||||
|
x = self.encoder(x * x_mask, x_mask)
|
||||||
|
stats = self.proj(x) * x_mask
|
||||||
|
|
||||||
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||||
|
return x, m, logs, x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualCouplingBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
num_layers,
|
||||||
|
dropout_p=0,
|
||||||
|
cond_channels=0,
|
||||||
|
mean_only=False,
|
||||||
|
):
|
||||||
|
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||||
|
super().__init__()
|
||||||
|
self.half_channels = channels // 2
|
||||||
|
self.mean_only = mean_only
|
||||||
|
# input layer
|
||||||
|
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||||
|
# coupling layers
|
||||||
|
self.enc = WN(
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
num_layers,
|
||||||
|
dropout_p=dropout_p,
|
||||||
|
c_in_channels=cond_channels,
|
||||||
|
)
|
||||||
|
# output layer
|
||||||
|
# Initializing last layer to 0 makes the affine coupling layers
|
||||||
|
# do nothing at first. This helps with training stability
|
||||||
|
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||||
|
self.post.weight.data.zero_()
|
||||||
|
self.post.bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None, reverse=False):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, C, T]`
|
||||||
|
- x_mask: :math:`[B, 1, T]`
|
||||||
|
- g: :math:`[B, C, 1]`
|
||||||
|
"""
|
||||||
|
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||||
|
h = self.pre(x0) * x_mask
|
||||||
|
h = self.enc(h, x_mask, g=g)
|
||||||
|
stats = self.post(h) * x_mask
|
||||||
|
if not self.mean_only:
|
||||||
|
m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
|
||||||
|
else:
|
||||||
|
m = stats
|
||||||
|
log_scale = torch.zeros_like(m)
|
||||||
|
|
||||||
|
if not reverse:
|
||||||
|
x1 = m + x1 * torch.exp(log_scale) * x_mask
|
||||||
|
x = torch.cat([x0, x1], 1)
|
||||||
|
logdet = torch.sum(log_scale, [1, 2])
|
||||||
|
return x, logdet
|
||||||
|
else:
|
||||||
|
x1 = (x1 - m) * torch.exp(-log_scale) * x_mask
|
||||||
|
x = torch.cat([x0, x1], 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualCouplingBlocks(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
hidden_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
dilation_rate: int,
|
||||||
|
num_layers: int,
|
||||||
|
num_flows=4,
|
||||||
|
cond_channels=0,
|
||||||
|
):
|
||||||
|
"""Redisual Coupling blocks for VITS flow layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels (int): Number of input and output tensor channels.
|
||||||
|
hidden_channels (int): Number of hidden network channels.
|
||||||
|
kernel_size (int): Kernel size of the WaveNet layers.
|
||||||
|
dilation_rate (int): Dilation rate of the WaveNet layers.
|
||||||
|
num_layers (int): Number of the WaveNet layers.
|
||||||
|
num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
|
||||||
|
cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.num_flows = num_flows
|
||||||
|
self.cond_channels = cond_channels
|
||||||
|
|
||||||
|
self.flows = nn.ModuleList()
|
||||||
|
for _ in range(num_flows):
|
||||||
|
self.flows.append(
|
||||||
|
ResidualCouplingBlock(
|
||||||
|
channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
num_layers,
|
||||||
|
cond_channels=cond_channels,
|
||||||
|
mean_only=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None, reverse=False):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, C, T]`
|
||||||
|
- x_mask: :math:`[B, 1, T]`
|
||||||
|
- g: :math:`[B, C, 1]`
|
||||||
|
"""
|
||||||
|
if not reverse:
|
||||||
|
for flow in self.flows:
|
||||||
|
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||||
|
x = torch.flip(x, [1])
|
||||||
|
else:
|
||||||
|
for flow in reversed(self.flows):
|
||||||
|
x = torch.flip(x, [1])
|
||||||
|
x = flow(x, x_mask, g=g, reverse=reverse)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PosteriorEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
hidden_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
dilation_rate: int,
|
||||||
|
num_layers: int,
|
||||||
|
cond_channels=0,
|
||||||
|
):
|
||||||
|
"""Posterior Encoder of VITS model.
|
||||||
|
|
||||||
|
::
|
||||||
|
x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input tensor channels.
|
||||||
|
out_channels (int): Number of output tensor channels.
|
||||||
|
hidden_channels (int): Number of hidden channels.
|
||||||
|
kernel_size (int): Kernel size of the WaveNet convolution layers.
|
||||||
|
dilation_rate (int): Dilation rate of the WaveNet layers.
|
||||||
|
num_layers (int): Number of the WaveNet layers.
|
||||||
|
cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.cond_channels = cond_channels
|
||||||
|
|
||||||
|
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||||
|
self.enc = WN(
|
||||||
|
hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
|
||||||
|
)
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
|
def forward(self, x, x_lengths, g=None):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, C, T]`
|
||||||
|
- x_lengths: :math:`[B, 1]`
|
||||||
|
- g: :math:`[B, C, 1]`
|
||||||
|
"""
|
||||||
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||||
|
x = self.pre(x) * x_mask
|
||||||
|
x = self.enc(x, x_mask, g=g)
|
||||||
|
stats = self.proj(x) * x_mask
|
||||||
|
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
|
||||||
|
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
|
||||||
|
return z, mean, log_scale, x_mask
|
|
@ -0,0 +1,276 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from TTS.tts.layers.generic.normalization import LayerNorm2
|
||||||
|
from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform
|
||||||
|
|
||||||
|
|
||||||
|
class DilatedDepthSeparableConv(nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor:
|
||||||
|
"""Dilated Depth-wise Separable Convolution module.
|
||||||
|
|
||||||
|
::
|
||||||
|
x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o
|
||||||
|
|-------------------------------------------------------------------------------------^
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels ([type]): [description]
|
||||||
|
kernel_size ([type]): [description]
|
||||||
|
num_layers ([type]): [description]
|
||||||
|
dropout_p (float, optional): [description]. Defaults to 0.0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.tensor: Network output masked by the input sequence mask.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
self.convs_sep = nn.ModuleList()
|
||||||
|
self.convs_1x1 = nn.ModuleList()
|
||||||
|
self.norms_1 = nn.ModuleList()
|
||||||
|
self.norms_2 = nn.ModuleList()
|
||||||
|
for i in range(num_layers):
|
||||||
|
dilation = kernel_size ** i
|
||||||
|
padding = (kernel_size * dilation - dilation) // 2
|
||||||
|
self.convs_sep.append(
|
||||||
|
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
|
||||||
|
)
|
||||||
|
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
||||||
|
self.norms_1.append(LayerNorm2(channels))
|
||||||
|
self.norms_2.append(LayerNorm2(channels))
|
||||||
|
self.dropout = nn.Dropout(dropout_p)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, C, T]`
|
||||||
|
- x_mask: :math:`[B, 1, T]`
|
||||||
|
"""
|
||||||
|
if g is not None:
|
||||||
|
x = x + g
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
y = self.convs_sep[i](x * x_mask)
|
||||||
|
y = self.norms_1[i](y)
|
||||||
|
y = F.gelu(y)
|
||||||
|
y = self.convs_1x1[i](y)
|
||||||
|
y = self.norms_2[i](y)
|
||||||
|
y = F.gelu(y)
|
||||||
|
y = self.dropout(y)
|
||||||
|
x = x + y
|
||||||
|
return x * x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAffine(nn.Module):
|
||||||
|
"""Element-wise affine transform like no-population stats BatchNorm alternative.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels (int): Number of input tensor channels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels):
|
||||||
|
super().__init__()
|
||||||
|
self.translation = nn.Parameter(torch.zeros(channels, 1))
|
||||||
|
self.log_scale = nn.Parameter(torch.zeros(channels, 1))
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument
|
||||||
|
if not reverse:
|
||||||
|
y = (x * torch.exp(self.log_scale) + self.translation) * x_mask
|
||||||
|
logdet = torch.sum(self.log_scale * x_mask, [1, 2])
|
||||||
|
return y, logdet
|
||||||
|
x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvFlow(nn.Module):
|
||||||
|
"""Dilated depth separable convolutional based spline flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input tensor channels.
|
||||||
|
hidden_channels (int): Number of in network channels.
|
||||||
|
kernel_size (int): Convolutional kernel size.
|
||||||
|
num_layers (int): Number of convolutional layers.
|
||||||
|
num_bins (int, optional): Number of spline bins. Defaults to 10.
|
||||||
|
tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
hidden_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
num_layers: int,
|
||||||
|
num_bins=10,
|
||||||
|
tail_bound=5.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_bins = num_bins
|
||||||
|
self.tail_bound = tail_bound
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.half_channels = in_channels // 2
|
||||||
|
|
||||||
|
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||||
|
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0)
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1)
|
||||||
|
self.proj.weight.data.zero_()
|
||||||
|
self.proj.bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None, reverse=False):
|
||||||
|
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||||
|
h = self.pre(x0)
|
||||||
|
h = self.convs(h, x_mask, g=g)
|
||||||
|
h = self.proj(h) * x_mask
|
||||||
|
|
||||||
|
b, c, t = x0.shape
|
||||||
|
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
||||||
|
|
||||||
|
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels)
|
||||||
|
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels)
|
||||||
|
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
||||||
|
|
||||||
|
x1, logabsdet = piecewise_rational_quadratic_transform(
|
||||||
|
x1,
|
||||||
|
unnormalized_widths,
|
||||||
|
unnormalized_heights,
|
||||||
|
unnormalized_derivatives,
|
||||||
|
inverse=reverse,
|
||||||
|
tails="linear",
|
||||||
|
tail_bound=self.tail_bound,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.cat([x0, x1], 1) * x_mask
|
||||||
|
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
||||||
|
if not reverse:
|
||||||
|
return x, logdet
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class StochasticDurationPredictor(nn.Module):
|
||||||
|
"""Stochastic duration predictor with Spline Flows.
|
||||||
|
|
||||||
|
It applies Variational Dequantization and Variationsl Data Augmentation.
|
||||||
|
|
||||||
|
Paper:
|
||||||
|
SDP: https://arxiv.org/pdf/2106.06103.pdf
|
||||||
|
Spline Flow: https://arxiv.org/abs/1906.04032
|
||||||
|
|
||||||
|
::
|
||||||
|
## Inference
|
||||||
|
|
||||||
|
x -> TextCondEncoder() -> Flow() -> dr_hat
|
||||||
|
noise ----------------------^
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|---------------------|
|
||||||
|
x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise
|
||||||
|
d -> DurCondEncoder() -> ^ |
|
||||||
|
|------------------------------------------------------------------------------|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input tensor channels.
|
||||||
|
hidden_channels (int): Number of hidden channels.
|
||||||
|
kernel_size (int): Kernel size of convolutional layers.
|
||||||
|
dropout_p (float): Dropout rate.
|
||||||
|
num_flows (int, optional): Number of flow blocks. Defaults to 4.
|
||||||
|
cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# condition encoder text
|
||||||
|
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||||
|
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||||
|
|
||||||
|
# posterior encoder
|
||||||
|
self.flows = nn.ModuleList()
|
||||||
|
self.flows.append(ElementwiseAffine(2))
|
||||||
|
self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
|
||||||
|
|
||||||
|
# condition encoder duration
|
||||||
|
self.post_pre = nn.Conv1d(1, hidden_channels, 1)
|
||||||
|
self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
|
||||||
|
self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||||
|
|
||||||
|
# flow layers
|
||||||
|
self.post_flows = nn.ModuleList()
|
||||||
|
self.post_flows.append(ElementwiseAffine(2))
|
||||||
|
self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
|
||||||
|
|
||||||
|
if cond_channels != 0 and cond_channels is not None:
|
||||||
|
self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, C, T]`
|
||||||
|
- x_mask: :math:`[B, 1, T]`
|
||||||
|
- dr: :math:`[B, 1, T]`
|
||||||
|
- g: :math:`[B, C]`
|
||||||
|
"""
|
||||||
|
# condition encoder text
|
||||||
|
x = self.pre(x)
|
||||||
|
if g is not None:
|
||||||
|
x = x + self.cond(g)
|
||||||
|
x = self.convs(x, x_mask)
|
||||||
|
x = self.proj(x) * x_mask
|
||||||
|
|
||||||
|
if not reverse:
|
||||||
|
flows = self.flows
|
||||||
|
assert dr is not None
|
||||||
|
|
||||||
|
# condition encoder duration
|
||||||
|
h = self.post_pre(dr)
|
||||||
|
h = self.post_convs(h, x_mask)
|
||||||
|
h = self.post_proj(h) * x_mask
|
||||||
|
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||||
|
z_q = noise
|
||||||
|
|
||||||
|
# posterior encoder
|
||||||
|
logdet_tot_q = 0.0
|
||||||
|
for idx, flow in enumerate(self.post_flows):
|
||||||
|
z_q, logdet_q = flow(z_q, x_mask, g=(x + h))
|
||||||
|
logdet_tot_q = logdet_tot_q + logdet_q
|
||||||
|
if idx > 0:
|
||||||
|
z_q = torch.flip(z_q, [1])
|
||||||
|
|
||||||
|
z_u, z_v = torch.split(z_q, [1, 1], 1)
|
||||||
|
u = torch.sigmoid(z_u) * x_mask
|
||||||
|
z0 = (dr - u) * x_mask
|
||||||
|
|
||||||
|
# posterior encoder - neg log likelihood
|
||||||
|
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
|
||||||
|
nll_posterior_encoder = (
|
||||||
|
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
|
||||||
|
)
|
||||||
|
|
||||||
|
z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
|
||||||
|
logdet_tot = torch.sum(-z0, [1, 2])
|
||||||
|
z = torch.cat([z0, z_v], 1)
|
||||||
|
|
||||||
|
# flow layers
|
||||||
|
for idx, flow in enumerate(flows):
|
||||||
|
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||||
|
logdet_tot = logdet_tot + logdet
|
||||||
|
if idx > 0:
|
||||||
|
z = torch.flip(z, [1])
|
||||||
|
|
||||||
|
# flow layers - neg log likelihood
|
||||||
|
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
|
||||||
|
return nll_flow_layers + nll_posterior_encoder
|
||||||
|
|
||||||
|
flows = list(reversed(self.flows))
|
||||||
|
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||||
|
z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
||||||
|
for flow in flows:
|
||||||
|
z = torch.flip(z, [1])
|
||||||
|
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||||
|
|
||||||
|
z0, _ = torch.split(z, [1, 1], 1)
|
||||||
|
logw = z0
|
||||||
|
return logw
|
|
@ -0,0 +1,203 @@
|
||||||
|
# adopted from https://github.com/bayesiains/nflows
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
||||||
|
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
||||||
|
DEFAULT_MIN_DERIVATIVE = 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
def piecewise_rational_quadratic_transform(
|
||||||
|
inputs,
|
||||||
|
unnormalized_widths,
|
||||||
|
unnormalized_heights,
|
||||||
|
unnormalized_derivatives,
|
||||||
|
inverse=False,
|
||||||
|
tails=None,
|
||||||
|
tail_bound=1.0,
|
||||||
|
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||||
|
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||||
|
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||||
|
):
|
||||||
|
|
||||||
|
if tails is None:
|
||||||
|
spline_fn = rational_quadratic_spline
|
||||||
|
spline_kwargs = {}
|
||||||
|
else:
|
||||||
|
spline_fn = unconstrained_rational_quadratic_spline
|
||||||
|
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
||||||
|
|
||||||
|
outputs, logabsdet = spline_fn(
|
||||||
|
inputs=inputs,
|
||||||
|
unnormalized_widths=unnormalized_widths,
|
||||||
|
unnormalized_heights=unnormalized_heights,
|
||||||
|
unnormalized_derivatives=unnormalized_derivatives,
|
||||||
|
inverse=inverse,
|
||||||
|
min_bin_width=min_bin_width,
|
||||||
|
min_bin_height=min_bin_height,
|
||||||
|
min_derivative=min_derivative,
|
||||||
|
**spline_kwargs,
|
||||||
|
)
|
||||||
|
return outputs, logabsdet
|
||||||
|
|
||||||
|
|
||||||
|
def searchsorted(bin_locations, inputs, eps=1e-6):
|
||||||
|
bin_locations[..., -1] += eps
|
||||||
|
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
||||||
|
|
||||||
|
|
||||||
|
def unconstrained_rational_quadratic_spline(
|
||||||
|
inputs,
|
||||||
|
unnormalized_widths,
|
||||||
|
unnormalized_heights,
|
||||||
|
unnormalized_derivatives,
|
||||||
|
inverse=False,
|
||||||
|
tails="linear",
|
||||||
|
tail_bound=1.0,
|
||||||
|
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||||
|
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||||
|
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||||
|
):
|
||||||
|
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
||||||
|
outside_interval_mask = ~inside_interval_mask
|
||||||
|
|
||||||
|
outputs = torch.zeros_like(inputs)
|
||||||
|
logabsdet = torch.zeros_like(inputs)
|
||||||
|
|
||||||
|
if tails == "linear":
|
||||||
|
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
||||||
|
constant = np.log(np.exp(1 - min_derivative) - 1)
|
||||||
|
unnormalized_derivatives[..., 0] = constant
|
||||||
|
unnormalized_derivatives[..., -1] = constant
|
||||||
|
|
||||||
|
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
||||||
|
logabsdet[outside_interval_mask] = 0
|
||||||
|
else:
|
||||||
|
raise RuntimeError("{} tails are not implemented.".format(tails))
|
||||||
|
|
||||||
|
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
|
||||||
|
inputs=inputs[inside_interval_mask],
|
||||||
|
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
||||||
|
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
||||||
|
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
||||||
|
inverse=inverse,
|
||||||
|
left=-tail_bound,
|
||||||
|
right=tail_bound,
|
||||||
|
bottom=-tail_bound,
|
||||||
|
top=tail_bound,
|
||||||
|
min_bin_width=min_bin_width,
|
||||||
|
min_bin_height=min_bin_height,
|
||||||
|
min_derivative=min_derivative,
|
||||||
|
)
|
||||||
|
|
||||||
|
return outputs, logabsdet
|
||||||
|
|
||||||
|
|
||||||
|
def rational_quadratic_spline(
|
||||||
|
inputs,
|
||||||
|
unnormalized_widths,
|
||||||
|
unnormalized_heights,
|
||||||
|
unnormalized_derivatives,
|
||||||
|
inverse=False,
|
||||||
|
left=0.0,
|
||||||
|
right=1.0,
|
||||||
|
bottom=0.0,
|
||||||
|
top=1.0,
|
||||||
|
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||||
|
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||||
|
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||||
|
):
|
||||||
|
if torch.min(inputs) < left or torch.max(inputs) > right:
|
||||||
|
raise ValueError("Input to a transform is not within its domain")
|
||||||
|
|
||||||
|
num_bins = unnormalized_widths.shape[-1]
|
||||||
|
|
||||||
|
if min_bin_width * num_bins > 1.0:
|
||||||
|
raise ValueError("Minimal bin width too large for the number of bins")
|
||||||
|
if min_bin_height * num_bins > 1.0:
|
||||||
|
raise ValueError("Minimal bin height too large for the number of bins")
|
||||||
|
|
||||||
|
widths = F.softmax(unnormalized_widths, dim=-1)
|
||||||
|
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
||||||
|
cumwidths = torch.cumsum(widths, dim=-1)
|
||||||
|
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
|
||||||
|
cumwidths = (right - left) * cumwidths + left
|
||||||
|
cumwidths[..., 0] = left
|
||||||
|
cumwidths[..., -1] = right
|
||||||
|
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
||||||
|
|
||||||
|
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
||||||
|
|
||||||
|
heights = F.softmax(unnormalized_heights, dim=-1)
|
||||||
|
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
||||||
|
cumheights = torch.cumsum(heights, dim=-1)
|
||||||
|
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
|
||||||
|
cumheights = (top - bottom) * cumheights + bottom
|
||||||
|
cumheights[..., 0] = bottom
|
||||||
|
cumheights[..., -1] = top
|
||||||
|
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
||||||
|
|
||||||
|
if inverse:
|
||||||
|
bin_idx = searchsorted(cumheights, inputs)[..., None]
|
||||||
|
else:
|
||||||
|
bin_idx = searchsorted(cumwidths, inputs)[..., None]
|
||||||
|
|
||||||
|
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
||||||
|
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
||||||
|
|
||||||
|
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
||||||
|
delta = heights / widths
|
||||||
|
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
||||||
|
|
||||||
|
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
||||||
|
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
||||||
|
|
||||||
|
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
||||||
|
|
||||||
|
if inverse:
|
||||||
|
a = (inputs - input_cumheights) * (
|
||||||
|
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||||
|
) + input_heights * (input_delta - input_derivatives)
|
||||||
|
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
||||||
|
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||||
|
)
|
||||||
|
c = -input_delta * (inputs - input_cumheights)
|
||||||
|
|
||||||
|
discriminant = b.pow(2) - 4 * a * c
|
||||||
|
assert (discriminant >= 0).all()
|
||||||
|
|
||||||
|
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
||||||
|
outputs = root * input_bin_widths + input_cumwidths
|
||||||
|
|
||||||
|
theta_one_minus_theta = root * (1 - root)
|
||||||
|
denominator = input_delta + (
|
||||||
|
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
|
||||||
|
)
|
||||||
|
derivative_numerator = input_delta.pow(2) * (
|
||||||
|
input_derivatives_plus_one * root.pow(2)
|
||||||
|
+ 2 * input_delta * theta_one_minus_theta
|
||||||
|
+ input_derivatives * (1 - root).pow(2)
|
||||||
|
)
|
||||||
|
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
||||||
|
|
||||||
|
return outputs, -logabsdet
|
||||||
|
else:
|
||||||
|
theta = (inputs - input_cumwidths) / input_bin_widths
|
||||||
|
theta_one_minus_theta = theta * (1 - theta)
|
||||||
|
|
||||||
|
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
|
||||||
|
denominator = input_delta + (
|
||||||
|
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
|
||||||
|
)
|
||||||
|
outputs = input_cumheights + numerator / denominator
|
||||||
|
|
||||||
|
derivative_numerator = input_delta.pow(2) * (
|
||||||
|
input_derivatives_plus_one * theta.pow(2)
|
||||||
|
+ 2 * input_delta * theta_one_minus_theta
|
||||||
|
+ input_derivatives * (1 - theta).pow(2)
|
||||||
|
)
|
||||||
|
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
||||||
|
|
||||||
|
return outputs, logabsdet
|
|
@ -4,20 +4,23 @@ from TTS.utils.generic_utils import find_module
|
||||||
|
|
||||||
def setup_model(config):
|
def setup_model(config):
|
||||||
print(" > Using model: {}".format(config.model))
|
print(" > Using model: {}".format(config.model))
|
||||||
|
|
||||||
MyModel = find_module("TTS.tts.models", config.model.lower())
|
MyModel = find_module("TTS.tts.models", config.model.lower())
|
||||||
# define set of characters used by the model
|
# define set of characters used by the model
|
||||||
if config.characters is not None:
|
if config.characters is not None:
|
||||||
# set characters from config
|
# set characters from config
|
||||||
symbols, phonemes = make_symbols(**config.characters.to_dict()) # pylint: disable=redefined-outer-name
|
if hasattr(MyModel, "make_symbols"):
|
||||||
|
symbols = MyModel.make_symbols(config)
|
||||||
|
else:
|
||||||
|
symbols, phonemes = make_symbols(**config.characters)
|
||||||
else:
|
else:
|
||||||
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
|
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
if config.use_phonemes:
|
||||||
|
symbols = phonemes
|
||||||
# use default characters and assign them to config
|
# use default characters and assign them to config
|
||||||
config.characters = parse_symbols()
|
config.characters = parse_symbols()
|
||||||
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
|
|
||||||
# consider special `blank` character if `add_blank` is set True
|
# consider special `blank` character if `add_blank` is set True
|
||||||
num_chars = num_chars + getattr(config, "add_blank", False)
|
num_chars = len(symbols) + getattr(config, "add_blank", False)
|
||||||
config.num_chars = num_chars
|
config.num_chars = num_chars
|
||||||
# compatibility fix
|
# compatibility fix
|
||||||
if "model_params" in config:
|
if "model_params" in config:
|
||||||
|
|
|
@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
|
||||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self, config, checkpoint_path, eval=False
|
self, config, checkpoint_path, eval=False
|
||||||
): # pylint: disable=unused-argument, redefined-builtin
|
): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
self.load_state_dict(state["model"])
|
self.load_state_dict(state["model"])
|
||||||
if eval:
|
if eval:
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
|
@@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.tts.utils.text import make_symbols
 from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler


@@ -75,9 +76,6 @@ class BaseTacotron(BaseTTS):
         self.decoder_backward = None
         self.coarse_decoder = None

-        # init multi-speaker layers
-        self.init_multispeaker(config)
-
     @staticmethod
     def _format_aux_input(aux_input: Dict) -> Dict:
         return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
@@ -113,7 +111,7 @@ class BaseTacotron(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if "r" in state:
             self.decoder.set_r(state["r"])
@@ -236,6 +234,7 @@ class BaseTacotron(BaseTTS):
     def compute_gst(self, inputs, style_input, speaker_embedding=None):
         """Compute global style token"""
         if isinstance(style_input, dict):
+            # multiply each style token with a weight
             query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
             if speaker_embedding is not None:
                 query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
@@ -247,8 +246,10 @@ class BaseTacotron(BaseTTS):
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
+            # ignore style token and return zero tensor
             gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
         else:
+            # compute style tokens
             gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
         inputs = self._concat_speaker_embedding(inputs, gst_outputs)
         return inputs
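Following the comments added to `compute_gst` above, `style_input` can be a dict of token weights, a reference spectrogram, or None. A usage sketch; `model` stands for a hypothetical GST-enabled Tacotron instance, and the tensors and dict values are placeholders:

    import torch

    encoder_outputs = torch.randn(1, 50, 256)  # [B, T_in, C], placeholder encoder output
    style_mel = torch.randn(1, 120, 80)        # placeholder reference mel spectrogram

    out = model.compute_gst(encoder_outputs, {"0": 0.3, "3": -0.1})  # dict: weight individual style tokens
    out = model.compute_gst(encoder_outputs, style_mel)              # tensor: infer style from a reference mel
    out = model.compute_gst(encoder_outputs, None)                   # None: a zero GST vector is concatenated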
@@ -1,6 +1,6 @@
+import os
 from typing import Dict, List, Tuple

-import numpy as np
 import torch
 from coqpit import Coqpit
 from torch import nn
@@ -48,10 +48,17 @@ class BaseTTS(BaseModel):
         return get_speaker_manager(config, restore_path, data, out_path)

     def init_multispeaker(self, config: Coqpit, data: List = None):
-        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
-        or with external `d_vectors` computed from a speaker encoder model.
+        """Initialize a speaker embedding layer if needed and define the expected embedding channel size for
+        setting the `in_channels` size of the connected layers.

-        If you need a different behaviour, override this function for your model.
+        This implementation yields 3 possible outcomes:
+
+        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
+        2. If `config.use_d_vector_file` is True, set the expected embedding channel size to `config.d_vector_dim` or 512.
+        3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
+           `config.d_vector_dim` or 512.
+
+        You can override this function for new models.

         Args:
             config (Coqpit): Model configuration.
@@ -59,12 +66,24 @@ class BaseTTS(BaseModel):
         """
         # init speaker manager
         self.speaker_manager = get_speaker_manager(config, data=data)
-        self.num_speakers = self.speaker_manager.num_speakers
-        # init speaker embedding layer
-        if config.use_speaker_embedding and not config.use_d_vector_file:
+        # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
+        if data is not None or self.speaker_manager.speaker_ids:
+            self.num_speakers = self.speaker_manager.num_speakers
+        else:
+            self.num_speakers = (
+                config.num_speakers
+                if "num_speakers" in config and config.num_speakers != 0
+                else self.speaker_manager.num_speakers
+            )
+
+        # set ultimate speaker embedding size
+        if config.use_speaker_embedding or config.use_d_vector_file:
             self.embedded_speaker_dim = (
                 config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
             )
+        # init speaker embedding layer
+        if config.use_speaker_embedding and not config.use_d_vector_file:
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
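The three outcomes listed in the new docstring boil down to the following decision; the config values below are only illustrative:

    config = {"use_speaker_embedding": False, "use_d_vector_file": True, "d_vector_dim": 256}

    if not config["use_speaker_embedding"] and not config["use_d_vector_file"]:
        pass  # 1. single-speaker setup: no embedding layer and no extra channels
    elif config["use_d_vector_file"]:
        # 2. external d-vectors: only the expected channel size is recorded
        embedded_speaker_dim = config["d_vector_dim"] or 512
    else:
        # 3. learned speaker embeddings: an nn.Embedding(num_speakers, embedded_speaker_dim) layer is created
        embedded_speaker_dim = config["d_vector_dim"] or 512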
@@ -87,7 +106,7 @@ class BaseTTS(BaseModel):
         text_input = batch[0]
         text_lengths = batch[1]
         speaker_names = batch[2]
-        linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
+        linear_input = batch[3]
         mel_input = batch[4]
         mel_lengths = batch[5]
         stop_targets = batch[6]
@@ -95,6 +114,7 @@ class BaseTTS(BaseModel):
         d_vectors = batch[8]
         speaker_ids = batch[9]
         attn_mask = batch[10]
+        waveform = batch[11]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())

@@ -140,6 +160,7 @@ class BaseTTS(BaseModel):
             "max_text_length": float(max_text_length),
             "max_spec_length": float(max_spec_length),
             "item_idx": item_idx,
+            "waveform": waveform,
         }

     def get_data_loader(
@@ -160,15 +181,22 @@ class BaseTTS(BaseModel):
             speaker_id_mapping = None
             d_vector_mapping = None

+            # setup custom symbols if needed
+            custom_symbols = None
+            if hasattr(self, "make_symbols"):
+                custom_symbols = self.make_symbols(self.config)
+
             # init dataloader
             dataset = TTSDataset(
                 outputs_per_step=config.r if "r" in config else 1,
                 text_cleaner=config.text_cleaner,
-                compute_linear_spec=config.model.lower() == "tacotron",
+                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 meta_data=data_items,
                 ap=ap,
                 characters=config.characters,
+                custom_symbols=custom_symbols,
                 add_blank=config["add_blank"],
+                return_wav=config.return_wav if "return_wav" in config else False,
                 batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                 min_seq_len=config.min_seq_len,
                 max_seq_len=config.max_seq_len,
@@ -185,8 +213,21 @@ class BaseTTS(BaseModel):
             )

             if config.use_phonemes and config.compute_input_seq_cache:
-                # precompute phonemes to have a better estimate of sequence lengths.
-                dataset.compute_input_seq(config.num_loader_workers)
+                if hasattr(self, "eval_data_items") and is_eval:
+                    dataset.items = self.eval_data_items
+                elif hasattr(self, "train_data_items") and not is_eval:
+                    dataset.items = self.train_data_items
+                else:
+                    # precompute phonemes to have a better estimate of sequence lengths.
+                    dataset.compute_input_seq(config.num_loader_workers)
+
+                    # TODO: find a more efficient solution
+                    # cheap hack - store items in the model state to avoid recomputing when reinit the dataset
+                    if is_eval:
+                        self.eval_data_items = dataset.items
+                    else:
+                        self.train_data_items = dataset.items
+
             dataset.sort_items()

             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
@@ -216,7 +257,7 @@ class BaseTTS(BaseModel):
         test_sentences = self.config.test_sentences
         aux_inputs = self.get_aux_input()
         for idx, sen in enumerate(test_sentences):
-            wav, alignment, model_outputs, _ = synthesis(
+            outputs_dict = synthesis(
                self,
                sen,
                self.config,
@@ -228,9 +269,12 @@ class BaseTTS(BaseModel):
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
-            ).values()
-
-            test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
-            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
+            )
+            test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
+            )
+            test_figures["{}-alignment".format(idx)] = plot_alignment(
+                outputs_dict["outputs"]["alignments"], output_fig=False
+            )
         return test_figures, test_audios
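With this refactor `synthesis` returns a dictionary instead of a positional tuple. A short consumption sketch using only the keys that appear in the diff (`model`, `config`, `use_cuda` and `ap` are assumed to be in scope):

    outputs_dict = synthesis(model, "Hello world.", config, use_cuda, ap, use_griffin_lim=True)

    wav = outputs_dict["wav"]                        # waveform logged as test audio
    spec = outputs_dict["outputs"]["model_outputs"]  # predicted spectrogram for plotting
    align = outputs_dict["outputs"]["alignments"]    # attention map for plotting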
@@ -12,14 +12,19 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import get_speaker_manager
+from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec


 class GlowTTS(BaseTTS):
-    """Glow TTS models from https://arxiv.org/abs/2005.11129
+    """GlowTTS model.

-    Paper abstract:
+    Paper::
+        https://arxiv.org/abs/2005.11129
+
+    Paper abstract::
         Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate
         mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained
         without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS,
@@ -144,7 +149,6 @@ class GlowTTS(BaseTTS):
             g = F.normalize(g).unsqueeze(-1)
         else:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
-
         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
         # drop redisual frames wrt num_squeeze and set y_lengths.
@@ -361,12 +365,49 @@ class GlowTTS(BaseTTS):
         train_audio = ap.inv_melspectrogram(pred_spec.T)
         return figures, {"audio": train_audio}

+    @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module):
         return self.train_step(batch, criterion)

     def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
         return self.train_log(ap, batch, outputs)

+    @torch.no_grad()
+    def test_run(self, ap):
+        """Generic test run for `tts` models used by `Trainer`.
+
+        You can override this for a different behaviour.
+
+        Returns:
+            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+        """
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self.get_aux_input()
+        for idx, sen in enumerate(test_sentences):
+            outputs = synthesis(
+                self,
+                sen,
+                self.config,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
+                speaker_id=aux_inputs["speaker_id"],
+                d_vector=aux_inputs["d_vector"],
+                style_wav=aux_inputs["style_wav"],
+                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                use_griffin_lim=True,
+                do_trim_silence=False,
+            )
+
+            test_audios["{}-audio".format(idx)] = outputs["wav"]
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                outputs["outputs"]["model_outputs"], ap, output_fig=False
+            )
+            test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
+        return test_figures, test_audios
+
     def preprocess(self, y, y_lengths, y_max_length, attn=None):
         if y_max_length is not None:
             y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
@@ -382,7 +423,7 @@ class GlowTTS(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec


 @dataclass
@@ -105,7 +106,7 @@ class SpeedySpeech(BaseTTS):
             if isinstance(config.model_args.length_scale, int)
             else config.model_args.length_scale
         )
-        self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
+        self.emb = nn.Embedding(self.num_chars, config.model_args.hidden_channels)
         self.encoder = Encoder(
             config.model_args.hidden_channels,
             config.model_args.hidden_channels,
@@ -227,6 +228,7 @@ class SpeedySpeech(BaseTTS):
         outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
         return outputs

+    @torch.no_grad()
     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
         """
         Shapes:
@@ -306,7 +308,7 @@ class SpeedySpeech(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -23,19 +23,19 @@ class Tacotron(BaseTacotron):
     def __init__(self, config: Coqpit):
         super().__init__(config)

-        self.num_chars, self.config = self.get_characters(config)
+        chars, self.config = self.get_characters(config)
+        config.num_chars = self.num_chars = len(chars)

         # pass all config fields to `self`
         # for fewer code change
         for key in config:
             setattr(self, key, config[key])

-        # speaker embedding layer
-        if self.num_speakers > 1:
+        # set speaker embedding channel size for determining `in_channels` for the connected layers.
+        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
+        # on the number of speakers inferred from the dataset.
+        if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
-
-            # speaker and gst embeddings is concat in decoder input
-            if self.num_speakers > 1:
-                self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim
+            self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

         if self.use_gst:
@@ -75,13 +75,11 @@ class Tacotron(BaseTacotron):
         if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=self.decoder_output_dim,
-                d_vector_dim=self.d_vector_dim
-                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
-                else None,
                 num_heads=self.gst.gst_num_heads,
                 num_style_tokens=self.gst.gst_num_style_tokens,
                 gst_embedding_dim=self.gst.gst_embedding_dim,
             )

         # backward pass decoder
         if self.bidirectional_decoder:
             self._init_backward_decoder()
@@ -106,7 +104,9 @@ class Tacotron(BaseTacotron):
             self.max_decoder_steps,
         )

-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
+    def forward(  # pylint: disable=dangerous-default-value
+        self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
+    ):
         """
         Shapes:
             text: [B, T_in]
@@ -115,6 +115,7 @@ class Tacotron(BaseTacotron):
             mel_lengths: [B]
             aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
         """
+        aux_input = self._format_aux_input(aux_input)
         outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
         inputs = self.embedding(text)
         input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -125,12 +126,10 @@ class Tacotron(BaseTacotron):
         # global style token
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(
-                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
-            )
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
         # speaker embedding
-        if self.num_speakers > 1:
-            if not self.use_d_vectors:
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
             else:
@@ -182,7 +181,7 @@ class Tacotron(BaseTacotron):
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])
                 # reshape embedded_speakers
@@ -23,7 +23,7 @@ class Tacotron2(BaseTacotron):
         super().__init__(config)

         chars, self.config = self.get_characters(config)
-        self.num_chars = len(chars)
+        config.num_chars = len(chars)
         self.decoder_output_dim = config.out_channels

         # pass all config fields to `self`
@@ -31,12 +31,11 @@ class Tacotron2(BaseTacotron):
         for key in config:
             setattr(self, key, config[key])

-        # speaker embedding layer
-        if self.num_speakers > 1:
+        # set speaker embedding channel size for determining `in_channels` for the connected layers.
+        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
+        # on the number of speakers inferred from the dataset.
+        if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
-
-            # speaker and gst embeddings is concat in decoder input
-            if self.num_speakers > 1:
-                self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim
+            self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

         if self.use_gst:
@@ -47,6 +46,7 @@ class Tacotron2(BaseTacotron):

         # base model layers
         self.encoder = Encoder(self.encoder_in_features)
+
         self.decoder = Decoder(
             self.decoder_in_features,
             self.decoder_output_dim,
@@ -73,9 +73,6 @@ class Tacotron2(BaseTacotron):
         if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=self.decoder_output_dim,
-                d_vector_dim=self.d_vector_dim
-                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
-                else None,
                 num_heads=self.gst.gst_num_heads,
                 num_style_tokens=self.gst.gst_num_style_tokens,
                 gst_embedding_dim=self.gst.gst_embedding_dim,
@@ -110,7 +107,9 @@ class Tacotron2(BaseTacotron):
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments

-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
+    def forward(  # pylint: disable=dangerous-default-value
+        self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
+    ):
         """
         Shapes:
             text: [B, T_in]
@@ -130,11 +129,10 @@ class Tacotron2(BaseTacotron):
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(
-                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
-            )
-        if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
             else:
@@ -186,8 +184,9 @@ class Tacotron2(BaseTacotron):
         if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
+
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            if not self.use_d_vector_file:
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
                 # reshape embedded_speakers
                 if embedded_speakers.ndim == 1:
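Both Tacotron forwards above now take speaker conditioning through a default `aux_input` dict. A calling sketch with illustrative shapes; `model` stands for a hypothetical Tacotron2 instance, and a real batch would come from `format_batch`:

    import torch

    text = torch.randint(0, 30, (2, 50))   # [B, T_in] character ids
    text_lengths = torch.tensor([50, 42])
    mel_specs = torch.randn(2, 120, 80)     # [B, T_out, C]
    mel_lengths = torch.tensor([120, 100])
    aux_input = {"speaker_ids": torch.tensor([0, 3]), "d_vectors": None}

    outputs = model.forward(text, text_lengths, mel_specs, mel_lengths, aux_input=aux_input)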
@@ -0,0 +1,767 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple
+
+import torch
+from coqpit import Coqpit
+from torch import nn
+from torch.cuda.amp.autocast_mode import autocast
+
+from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
+from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.layers.vits.discriminator import VitsDiscriminator
+from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
+from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
+
+# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
+from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.speakers import get_speaker_manager
+from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.visual import plot_alignment
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.trainer_utils import get_optimizer, get_scheduler
+from TTS.vocoder.models.hifigan_generator import HifiganGenerator
+from TTS.vocoder.utils.generic_utils import plot_results
+
+
+def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
+    """Segment each sample in a batch based on the provided segment indices"""
+    segments = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        index_start = segment_indices[i]
+        index_end = index_start + segment_size
+        segments[i] = x[i, :, index_start:index_end]
+    return segments
+
+
+def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
+    """Create random segments based on the input lengths."""
+    B, _, T = x.size()
+    if x_lengths is None:
+        x_lengths = T
+    max_idxs = x_lengths - segment_size + 1
+    assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
+    ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
+    ret = segment(x, ids_str, segment_size)
+    return ret, ids_str
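A quick check of the two helpers above: `rand_segment` draws one random `segment_size`-frame window per batch item, which the waveform decoder later upsamples. A small example (shapes are arbitrary):

    import torch

    z = torch.randn(3, 192, 100)            # [B, C, T] latent frames
    z_lengths = torch.tensor([100, 80, 60])
    z_slice, slice_ids = rand_segment(z, z_lengths, segment_size=4)
    print(z_slice.shape)                     # torch.Size([3, 192, 4])
    print(slice_ids)                         # random start index per item, each < length - 3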
|
@dataclass
|
||||||
|
class VitsArgs(Coqpit):
|
||||||
|
"""VITS model arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
num_chars (int):
|
||||||
|
Number of characters in the vocabulary. Defaults to 100.
|
||||||
|
|
||||||
|
out_channels (int):
|
||||||
|
Number of output channels. Defaults to 513.
|
||||||
|
|
||||||
|
spec_segment_size (int):
|
||||||
|
Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`.
|
||||||
|
|
||||||
|
hidden_channels (int):
|
||||||
|
Number of hidden channels of the model. Defaults to 192.
|
||||||
|
|
||||||
|
hidden_channels_ffn_text_encoder (int):
|
||||||
|
Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 768.
|
||||||
|
|
||||||
|
num_heads_text_encoder (int):
|
||||||
|
Number of attention heads of the text encoder transformer. Defaults to 2.
|
||||||
|
|
||||||
|
num_layers_text_encoder (int):
|
||||||
|
Number of transformer layers in the text encoder. Defaults to 6.
|
||||||
|
|
||||||
|
kernel_size_text_encoder (int):
|
||||||
|
Kernel size of the text encoder transformer FFN layers. Defaults to 3.
|
||||||
|
|
||||||
|
dropout_p_text_encoder (float):
|
||||||
|
Dropout rate of the text encoder. Defaults to 0.1.
|
||||||
|
|
||||||
|
dropout_p_duration_predictor (float):
|
||||||
|
Dropout rate of the duration predictor. Defaults to 0.1.
|
||||||
|
|
||||||
|
kernel_size_posterior_encoder (int):
|
||||||
|
Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.
|
||||||
|
|
||||||
|
dilation_rate_posterior_encoder (int):
|
||||||
|
Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.
|
||||||
|
|
||||||
|
num_layers_posterior_encoder (int):
|
||||||
|
Number of posterior encoder's WaveNet layers. Defaults to 16.
|
||||||
|
|
||||||
|
kernel_size_flow (int):
|
||||||
|
Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
|
||||||
|
|
||||||
|
dilation_rate_flow (int):
|
||||||
|
Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
|
||||||
|
|
||||||
|
num_layers_flow (int):
|
||||||
|
Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.
|
||||||
|
|
||||||
|
resblock_type_decoder (str):
|
||||||
|
Type of the residual block in the decoder network. Defaults to "1".
|
||||||
|
|
||||||
|
resblock_kernel_sizes_decoder (List[int]):
|
||||||
|
Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
|
||||||
|
|
||||||
|
resblock_dilation_sizes_decoder (List[List[int]]):
|
||||||
|
Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
|
||||||
|
|
||||||
|
upsample_rates_decoder (List[int]):
|
||||||
|
Upsampling rates for each concecutive upsampling layer in the decoder network. The multiply of these
|
||||||
|
values must be equal to the kop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
|
||||||
|
|
||||||
|
upsample_initial_channel_decoder (int):
|
||||||
|
Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
|
||||||
|
|
||||||
|
upsample_kernel_sizes_decoder (List[int]):
|
||||||
|
Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
|
||||||
|
|
||||||
|
use_sdp (int):
|
||||||
|
Use Stochastic Duration Predictor. Defaults to True.
|
||||||
|
|
||||||
|
noise_scale (float):
|
||||||
|
Noise scale used for the sample noise tensor in training. Defaults to 1.0.
|
||||||
|
|
||||||
|
inference_noise_scale (float):
|
||||||
|
Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
|
||||||
|
|
||||||
|
length_scale (int):
|
||||||
|
Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.
|
||||||
|
|
||||||
|
noise_scale_dp (float):
|
||||||
|
Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
|
||||||
|
|
||||||
|
inference_noise_scale_dp (float):
|
||||||
|
Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8.
|
||||||
|
|
||||||
|
max_inference_len (int):
|
||||||
|
Maximum inference length to limit the memory use. Defaults to None.
|
||||||
|
|
||||||
|
init_discriminator (bool):
|
||||||
|
Initialize the discriminator network if set True. Set False for inference. Defaults to True.
|
||||||
|
|
||||||
|
use_spectral_norm_disriminator (bool):
|
||||||
|
Use spectral normalization over weight norm in the discriminator. Defaults to False.
|
||||||
|
|
||||||
|
use_speaker_embedding (bool):
|
||||||
|
Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
|
||||||
|
|
||||||
|
num_speakers (int):
|
||||||
|
Number of speakers for the speaker embedding layer. Defaults to 0.
|
||||||
|
|
||||||
|
speakers_file (str):
|
||||||
|
Path to the speaker mapping file for the Speaker Manager. Defaults to None.
|
||||||
|
|
||||||
|
speaker_embedding_channels (int):
|
||||||
|
Number of speaker embedding channels. Defaults to 256.
|
||||||
|
|
||||||
|
use_d_vector_file (bool):
|
||||||
|
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
|
||||||
|
|
||||||
|
d_vector_dim (int):
|
||||||
|
Number of d-vector channels. Defaults to 0.
|
||||||
|
|
||||||
|
detach_dp_input (bool):
|
||||||
|
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_chars: int = 100
|
||||||
|
out_channels: int = 513
|
||||||
|
spec_segment_size: int = 32
|
||||||
|
hidden_channels: int = 192
|
||||||
|
hidden_channels_ffn_text_encoder: int = 768
|
||||||
|
num_heads_text_encoder: int = 2
|
||||||
|
num_layers_text_encoder: int = 6
|
||||||
|
kernel_size_text_encoder: int = 3
|
||||||
|
dropout_p_text_encoder: int = 0.1
|
||||||
|
dropout_p_duration_predictor: int = 0.1
|
||||||
|
kernel_size_posterior_encoder: int = 5
|
||||||
|
dilation_rate_posterior_encoder: int = 1
|
||||||
|
num_layers_posterior_encoder: int = 16
|
||||||
|
kernel_size_flow: int = 5
|
||||||
|
dilation_rate_flow: int = 1
|
||||||
|
num_layers_flow: int = 4
|
||||||
|
resblock_type_decoder: int = "1"
|
||||||
|
resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
|
||||||
|
resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||||
|
upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
|
||||||
|
upsample_initial_channel_decoder: int = 512
|
||||||
|
upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
|
||||||
|
use_sdp: int = True
|
||||||
|
noise_scale: float = 1.0
|
||||||
|
inference_noise_scale: float = 0.667
|
||||||
|
length_scale: int = 1
|
||||||
|
noise_scale_dp: float = 1.0
|
||||||
|
inference_noise_scale_dp: float = 0.8
|
||||||
|
max_inference_len: int = None
|
||||||
|
init_discriminator: bool = True
|
||||||
|
use_spectral_norm_disriminator: bool = False
|
||||||
|
use_speaker_embedding: bool = False
|
||||||
|
num_speakers: int = 0
|
||||||
|
speakers_file: str = None
|
||||||
|
speaker_embedding_channels: int = 256
|
||||||
|
use_d_vector_file: bool = False
|
||||||
|
d_vector_dim: int = 0
|
||||||
|
detach_dp_input: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class Vits(BaseTTS):
|
||||||
|
"""VITS TTS model
|
||||||
|
|
||||||
|
Paper::
|
||||||
|
https://arxiv.org/pdf/2106.06103.pdf
|
||||||
|
|
||||||
|
Paper Abstract::
|
||||||
|
Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel
|
||||||
|
sampling have been proposed, but their sample quality does not match that of two-stage TTS systems.
|
||||||
|
In this work, we present a parallel end-to-end TTS method that generates more natural sounding audio than
|
||||||
|
current two-stage models. Our method adopts variational inference augmented with normalizing flows and
|
||||||
|
an adversarial training process, which improves the expressive power of generative modeling. We also propose a
|
||||||
|
stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the
|
||||||
|
uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the
|
||||||
|
natural one-to-many relationship in which a text input can be spoken in multiple ways
|
||||||
|
with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS)
|
||||||
|
on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly
|
||||||
|
available TTS systems and achieves a MOS comparable to ground truth.
|
||||||
|
|
||||||
|
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from TTS.tts.configs import VitsConfig
|
||||||
|
>>> from TTS.tts.models.vits import Vits
|
||||||
|
>>> config = VitsConfig()
|
||||||
|
>>> model = Vits(config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=dangerous-default-value
|
||||||
|
|
||||||
|
def __init__(self, config: Coqpit):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.END2END = True
|
||||||
|
|
||||||
|
if config.__class__.__name__ == "VitsConfig":
|
||||||
|
# loading from VitsConfig
|
||||||
|
if "num_chars" not in config:
|
||||||
|
_, self.config, num_chars = self.get_characters(config)
|
||||||
|
config.model_args.num_chars = num_chars
|
||||||
|
else:
|
||||||
|
self.config = config
|
||||||
|
config.model_args.num_chars = config.num_chars
|
||||||
|
args = self.config.model_args
|
||||||
|
elif isinstance(config, VitsArgs):
|
||||||
|
# loading from VitsArgs
|
||||||
|
self.config = config
|
||||||
|
args = config
|
||||||
|
else:
|
||||||
|
raise ValueError("config must be either a VitsConfig or VitsArgs")
|
||||||
|
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
self.init_multispeaker(config)
|
||||||
|
|
||||||
|
self.length_scale = args.length_scale
|
||||||
|
self.noise_scale = args.noise_scale
|
||||||
|
self.inference_noise_scale = args.inference_noise_scale
|
||||||
|
self.inference_noise_scale_dp = args.inference_noise_scale_dp
|
||||||
|
self.noise_scale_dp = args.noise_scale_dp
|
||||||
|
self.max_inference_len = args.max_inference_len
|
||||||
|
self.spec_segment_size = args.spec_segment_size
|
||||||
|
|
||||||
|
self.text_encoder = TextEncoder(
|
||||||
|
args.num_chars,
|
||||||
|
args.hidden_channels,
|
||||||
|
args.hidden_channels,
|
||||||
|
args.hidden_channels_ffn_text_encoder,
|
||||||
|
args.num_heads_text_encoder,
|
||||||
|
args.num_layers_text_encoder,
|
||||||
|
args.kernel_size_text_encoder,
|
||||||
|
args.dropout_p_text_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.posterior_encoder = PosteriorEncoder(
|
||||||
|
args.out_channels,
|
||||||
|
args.hidden_channels,
|
||||||
|
args.hidden_channels,
|
||||||
|
kernel_size=args.kernel_size_posterior_encoder,
|
||||||
|
dilation_rate=args.dilation_rate_posterior_encoder,
|
||||||
|
num_layers=args.num_layers_posterior_encoder,
|
||||||
|
cond_channels=self.embedded_speaker_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.flow = ResidualCouplingBlocks(
|
||||||
|
args.hidden_channels,
|
||||||
|
args.hidden_channels,
|
||||||
|
kernel_size=args.kernel_size_flow,
|
||||||
|
dilation_rate=args.dilation_rate_flow,
|
||||||
|
num_layers=args.num_layers_flow,
|
||||||
|
cond_channels=self.embedded_speaker_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.use_sdp:
|
||||||
|
self.duration_predictor = StochasticDurationPredictor(
|
||||||
|
args.hidden_channels,
|
||||||
|
192,
|
||||||
|
3,
|
||||||
|
args.dropout_p_duration_predictor,
|
||||||
|
4,
|
||||||
|
cond_channels=self.embedded_speaker_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.duration_predictor = DurationPredictor(
|
||||||
|
args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
self.waveform_decoder = HifiganGenerator(
|
||||||
|
args.hidden_channels,
|
||||||
|
1,
|
||||||
|
args.resblock_type_decoder,
|
||||||
|
args.resblock_dilation_sizes_decoder,
|
||||||
|
args.resblock_kernel_sizes_decoder,
|
||||||
|
args.upsample_kernel_sizes_decoder,
|
||||||
|
args.upsample_initial_channel_decoder,
|
||||||
|
args.upsample_rates_decoder,
|
||||||
|
inference_padding=0,
|
||||||
|
cond_channels=self.embedded_speaker_dim,
|
||||||
|
conv_pre_weight_norm=False,
|
||||||
|
conv_post_weight_norm=False,
|
||||||
|
conv_post_bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.init_discriminator:
|
||||||
|
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
|
||||||
|
|
||||||
|
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||||
|
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||||
|
or with external `d_vectors` computed from a speaker encoder model.
|
||||||
|
|
||||||
|
If you need a different behaviour, override this function for your model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (Coqpit): Model configuration.
|
||||||
|
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||||
|
"""
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
config = config.model_args
|
||||||
|
self.embedded_speaker_dim = 0
|
||||||
|
# init speaker manager
|
||||||
|
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||||
|
if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
|
||||||
|
self.speaker_manager.num_speakers = config.num_speakers
|
||||||
|
self.num_speakers = self.speaker_manager.num_speakers
|
||||||
|
# init speaker embedding layer
|
||||||
|
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||||
|
self.embedded_speaker_dim = config.speaker_embedding_channels
|
||||||
|
self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
|
||||||
|
# init d-vector usage
|
||||||
|
if config.use_d_vector_file:
|
||||||
|
self.embedded_speaker_dim = config.d_vector_dim
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _set_cond_input(aux_input: Dict):
|
||||||
|
"""Set the speaker conditioning input based on the multi-speaker mode."""
|
||||||
|
sid, g = None, None
|
||||||
|
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
|
||||||
|
sid = aux_input["speaker_ids"]
|
||||||
|
if sid.ndim == 0:
|
||||||
|
sid = sid.unsqueeze_(0)
|
||||||
|
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
|
||||||
|
g = aux_input["d_vectors"]
|
||||||
|
return sid, g
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.tensor,
|
||||||
|
x_lengths: torch.tensor,
|
||||||
|
y: torch.tensor,
|
||||||
|
y_lengths: torch.tensor,
|
||||||
|
aux_input={"d_vectors": None, "speaker_ids": None},
|
||||||
|
) -> Dict:
|
||||||
|
"""Forward pass of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.tensor): Batch of input character sequence IDs.
|
||||||
|
x_lengths (torch.tensor): Batch of input character sequence lengths.
|
||||||
|
y (torch.tensor): Batch of input spectrograms.
|
||||||
|
y_lengths (torch.tensor): Batch of input spectrogram lengths.
|
||||||
|
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: model outputs keyed by the output name.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, T_seq]`
|
||||||
|
- x_lengths: :math:`[B]`
|
||||||
|
- y: :math:`[B, C, T_spec]`
|
||||||
|
- y_lengths: :math:`[B]`
|
||||||
|
- d_vectors: :math:`[B, C, 1]`
|
||||||
|
- speaker_ids: :math:`[B]`
|
||||||
|
"""
|
||||||
|
outputs = {}
|
||||||
|
sid, g = self._set_cond_input(aux_input)
|
||||||
|
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
|
||||||
|
|
||||||
|
# speaker embedding
|
||||||
|
if self.num_speakers > 1 and sid is not None:
|
||||||
|
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
|
# posterior encoder
|
||||||
|
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
||||||
|
|
||||||
|
# flow layers
|
||||||
|
z_p = self.flow(z, y_mask, g=g)
|
||||||
|
|
||||||
|
# find the alignment path
|
||||||
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
|
with torch.no_grad():
|
||||||
|
o_scale = torch.exp(-2 * logs_p)
|
||||||
|
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
||||||
|
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
|
||||||
|
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
||||||
|
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||||
|
logp = logp2 + logp3
|
||||||
|
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||||
|
|
||||||
|
# duration predictor
|
||||||
|
attn_durations = attn.sum(3)
|
||||||
|
if self.args.use_sdp:
|
||||||
|
nll_duration = self.duration_predictor(
|
||||||
|
x.detach() if self.args.detach_dp_input else x,
|
||||||
|
x_mask,
|
||||||
|
attn_durations,
|
||||||
|
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||||
|
)
|
||||||
|
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
|
||||||
|
outputs["nll_duration"] = nll_duration
|
||||||
|
else:
|
||||||
|
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
|
||||||
|
log_durations = self.duration_predictor(
|
||||||
|
x.detach() if self.args.detach_dp_input else x,
|
||||||
|
x_mask,
|
||||||
|
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||||
|
)
|
||||||
|
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||||
|
outputs["loss_duration"] = loss_duration
|
||||||
|
|
||||||
|
# expand prior
|
||||||
|
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||||
|
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||||
|
|
||||||
|
# select a random feature segment for the waveform decoder
|
||||||
|
z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
|
||||||
|
o = self.waveform_decoder(z_slice, g=g)
|
||||||
|
outputs.update(
|
||||||
|
{
|
||||||
|
"model_outputs": o,
|
||||||
|
"alignments": attn.squeeze(1),
|
||||||
|
"slice_ids": slice_ids,
|
||||||
|
"z": z,
|
||||||
|
"z_p": z_p,
|
||||||
|
"m_p": m_p,
|
||||||
|
"logs_p": logs_p,
|
||||||
|
"m_q": m_q,
|
||||||
|
"logs_q": logs_q,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, T_seq]`
|
||||||
|
- d_vectors: :math:`[B, C, 1]`
|
||||||
|
- speaker_ids: :math:`[B]`
|
||||||
|
"""
|
||||||
|
sid, g = self._set_cond_input(aux_input)
|
||||||
|
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
|
|
||||||
|
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
|
||||||
|
|
||||||
|
if self.num_speakers > 0 and sid:
|
||||||
|
g = self.emb_g(sid).unsqueeze(-1)
|
||||||
|
|
||||||
|
if self.args.use_sdp:
|
||||||
|
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
|
||||||
|
else:
|
||||||
|
logw = self.duration_predictor(x, x_mask, g=g)
|
||||||
|
|
||||||
|
w = torch.exp(logw) * x_mask * self.length_scale
|
||||||
|
w_ceil = torch.ceil(w)
|
||||||
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||||
|
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
|
||||||
|
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||||
|
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
|
||||||
|
|
||||||
|
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
|
||||||
|
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
|
||||||
|
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
||||||
|
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
|
||||||
|
|
||||||
|
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
|
||||||
|
"""TODO: create an end-point for voice conversion"""
|
||||||
|
assert self.num_speakers > 0, "num_speakers have to be larger than 0."
|
||||||
|
g_src = self.emb_g(sid_src).unsqueeze(-1)
|
||||||
|
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
|
||||||
|
z, _, _, y_mask = self.enc_q(y, y_lengths, g=g_src)
|
||||||
|
z_p = self.flow(z, y_mask, g=g_src)
|
||||||
|
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
||||||
|
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
|
||||||
|
return o_hat, y_mask, (z, z_p, z_hat)
|
||||||
|
|
||||||
|
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||||
|
"""Perform a single training step. Run the model forward pass and compute losses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (Dict): Input tensors.
|
||||||
|
criterion (nn.Module): Loss layer designed for the model.
|
||||||
|
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict, Dict]: Model outputs and computed losses.
|
||||||
|
"""
|
||||||
|
# pylint: disable=attribute-defined-outside-init
|
||||||
|
if optimizer_idx not in [0, 1]:
|
||||||
|
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||||
|
|
||||||
|
if optimizer_idx == 0:
|
||||||
|
text_input = batch["text_input"]
|
||||||
|
text_lengths = batch["text_lengths"]
|
||||||
|
mel_lengths = batch["mel_lengths"]
|
||||||
|
linear_input = batch["linear_input"]
|
||||||
|
d_vectors = batch["d_vectors"]
|
||||||
|
speaker_ids = batch["speaker_ids"]
|
||||||
|
waveform = batch["waveform"]
|
||||||
|
|
||||||
|
# generator pass
|
||||||
|
outputs = self.forward(
|
||||||
|
text_input,
|
||||||
|
text_lengths,
|
||||||
|
linear_input.transpose(1, 2),
|
||||||
|
mel_lengths,
|
||||||
|
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
|
||||||
|
)
|
||||||
|
|
||||||
|
# cache tensors for the discriminator
|
||||||
|
self.y_disc_cache = None
|
||||||
|
self.wav_seg_disc_cache = None
|
||||||
|
self.y_disc_cache = outputs["model_outputs"]
|
||||||
|
wav_seg = segment(
|
||||||
|
waveform.transpose(1, 2),
|
||||||
|
outputs["slice_ids"] * self.config.audio.hop_length,
|
||||||
|
self.args.spec_segment_size * self.config.audio.hop_length,
|
||||||
|
)
|
||||||
|
self.wav_seg_disc_cache = wav_seg
|
||||||
|
outputs["waveform_seg"] = wav_seg
|
||||||
|
|
||||||
|
# compute discriminator scores and features
|
||||||
|
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
|
||||||
|
_, outputs["feats_disc_real"] = self.disc(wav_seg)
|
||||||
|
|
||||||
|
# compute losses
|
||||||
|
with autocast(enabled=False): # use float32 for the criterion
|
||||||
|
loss_dict = criterion[optimizer_idx](
|
||||||
|
waveform_hat=outputs["model_outputs"].float(),
|
||||||
|
waveform=wav_seg.float(),
|
||||||
|
z_p=outputs["z_p"].float(),
|
||||||
|
logs_q=outputs["logs_q"].float(),
|
||||||
|
m_p=outputs["m_p"].float(),
|
||||||
|
logs_p=outputs["logs_p"].float(),
|
||||||
|
z_len=mel_lengths,
|
||||||
|
scores_disc_fake=outputs["scores_disc_fake"],
|
||||||
|
feats_disc_fake=outputs["feats_disc_fake"],
|
||||||
|
feats_disc_real=outputs["feats_disc_real"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# handle the duration loss
|
||||||
|
if self.args.use_sdp:
|
||||||
|
loss_dict["nll_duration"] = outputs["nll_duration"]
|
||||||
|
loss_dict["loss"] += outputs["nll_duration"]
|
||||||
|
else:
|
||||||
|
loss_dict["loss_duration"] = outputs["loss_duration"]
|
||||||
|
loss_dict["loss"] += outputs["nll_duration"]
|
||||||
|
|
||||||
|
elif optimizer_idx == 1:
|
||||||
|
# discriminator pass
|
||||||
|
outputs = {}
|
||||||
|
|
||||||
|
# compute scores and features
|
||||||
|
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
|
||||||
|
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
|
||||||
|
|
||||||
|
# compute loss
|
||||||
|
with autocast(enabled=False): # use float32 for the criterion
|
||||||
|
loss_dict = criterion[optimizer_idx](
|
||||||
|
outputs["scores_disc_real"],
|
||||||
|
outputs["scores_disc_fake"],
|
||||||
|
)
|
||||||
|
return outputs, loss_dict
|
||||||
|
|
||||||
|
def train_log(
|
||||||
|
self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
|
||||||
|
): # pylint: disable=no-self-use
|
||||||
|
"""Create visualizations and waveform examples.
|
||||||
|
|
||||||
|
For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to
|
||||||
|
be projected onto Tensorboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ap (AudioProcessor): audio processor used at training.
|
||||||
|
batch (Dict): Model inputs used at the previous training step.
|
||||||
|
outputs (Dict): Model outputs generated at the previoud training step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict, np.ndarray]: training plots and output waveform.
|
||||||
|
"""
|
||||||
|
y_hat = outputs[0]["model_outputs"]
|
||||||
|
y = outputs[0]["waveform_seg"]
|
||||||
|
figures = plot_results(y_hat, y, ap, name_prefix)
|
||||||
|
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||||
|
audios = {f"{name_prefix}/audio": sample_voice}
|
||||||
|
|
||||||
|
alignments = outputs[0]["alignments"]
|
||||||
|
align_img = alignments[0].data.cpu().numpy().T
|
||||||
|
|
||||||
|
figures.update(
|
||||||
|
{
|
||||||
|
"alignment": plot_alignment(align_img, output_fig=False),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return figures, audios
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
|
||||||
|
return self.train_step(batch, criterion, optimizer_idx)
|
||||||
|
|
||||||
|
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||||
|
return self.train_log(ap, batch, outputs, "eval")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_run(self, ap) -> Tuple[Dict, Dict]:
|
||||||
|
"""Generic test run for `tts` models used by `Trainer`.
|
||||||
|
|
||||||
|
You can override this for a different behaviour.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||||
|
"""
|
||||||
|
print(" | > Synthesizing test sentences.")
|
||||||
|
test_audios = {}
|
||||||
|
test_figures = {}
|
||||||
|
test_sentences = self.config.test_sentences
|
||||||
|
aux_inputs = self.get_aux_input()
|
||||||
|
for idx, sen in enumerate(test_sentences):
|
||||||
|
wav, alignment, _, _ = synthesis(
|
||||||
|
self,
|
||||||
|
sen,
|
||||||
|
self.config,
|
||||||
|
"cuda" in str(next(self.parameters()).device),
|
||||||
|
ap,
|
||||||
|
speaker_id=aux_inputs["speaker_id"],
|
||||||
|
d_vector=aux_inputs["d_vector"],
|
||||||
|
style_wav=aux_inputs["style_wav"],
|
||||||
|
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||||
|
use_griffin_lim=True,
|
||||||
|
do_trim_silence=False,
|
||||||
|
).values()
|
||||||
|
|
||||||
|
test_audios["{}-audio".format(idx)] = wav
|
||||||
|
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
||||||
|
return test_figures, test_audios
|
||||||
|
|
||||||
|
def get_optimizer(self) -> List:
|
||||||
|
"""Initiate and return the GAN optimizers based on the config parameters.
|
||||||
|
|
||||||
|
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List: optimizers.
|
||||||
|
"""
|
||||||
|
self.disc.requires_grad_(False)
|
||||||
|
gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||||
|
self.disc.requires_grad_(True)
|
||||||
|
optimizer1 = get_optimizer(
|
||||||
|
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
||||||
|
)
|
||||||
|
optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
|
||||||
|
return [optimizer1, optimizer2]
|
||||||
|
|
||||||
|
def get_lr(self) -> List:
|
||||||
|
"""Set the initial learning rates for each optimizer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List: learning rates for each optimizer.
|
||||||
|
"""
|
||||||
|
return [self.config.lr_gen, self.config.lr_disc]
|
||||||
|
|
||||||
|
def get_scheduler(self, optimizer) -> List:
|
||||||
|
"""Set the schedulers for each optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List: Schedulers, one for each optimizer.
|
||||||
|
"""
|
||||||
|
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
|
||||||
|
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
|
||||||
|
return [scheduler1, scheduler2]
|
||||||
|
|
||||||
|
def get_criterion(self):
|
||||||
|
"""Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
|
||||||
|
`train_step()`"""
|
||||||
|
from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel
|
||||||
|
VitsDiscriminatorLoss,
|
||||||
|
VitsGeneratorLoss,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_symbols(config):
|
||||||
|
"""Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
|
||||||
|
whole training and inference steps."""
|
||||||
|
_pad = config.characters["pad"]
|
||||||
|
_punctuations = config.characters["punctuations"]
|
||||||
|
_letters = config.characters["characters"]
|
||||||
|
_letters_ipa = config.characters["phonemes"]
|
||||||
|
symbols = [_pad] + list(_punctuations) + list(_letters)
|
||||||
|
if config.use_phonemes:
|
||||||
|
symbols += list(_letters_ipa)
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_characters(config: Coqpit):
|
||||||
|
if config.characters is not None:
|
||||||
|
symbols = Vits.make_symbols(config)
|
||||||
|
else:
|
||||||
|
from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel
|
||||||
|
parse_symbols,
|
||||||
|
phonemes,
|
||||||
|
symbols,
|
||||||
|
)
|
||||||
|
|
||||||
|
config.characters = parse_symbols()
|
||||||
|
if config.use_phonemes:
|
||||||
|
symbols = phonemes
|
||||||
|
num_chars = len(symbols) + getattr(config, "add_blank", False)
|
||||||
|
return symbols, config, num_chars
|
||||||
|
|
||||||
|
def load_checkpoint(
|
||||||
|
self, config, checkpoint_path, eval=False
|
||||||
|
): # pylint: disable=unused-argument, redefined-builtin
|
||||||
|
"""Load the model checkpoint and setup for training or inference"""
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
|
self.load_state_dict(state["model"])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
|
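For orientation, here is a minimal, hypothetical driver loop for the two-optimizer API above: `get_criterion()` returns the generator and discriminator criteria in the same order as `get_optimizer()`, and `train_step()` is called once per `optimizer_idx`. This is only a sketch, not the actual `Trainer` code; `model` and `batch` are assumed to exist, and the `"loss"` key comes from the loss dicts built in `train_step()`.

```python
# Hypothetical training-loop sketch; NOT the Trainer implementation shipped with 🐸TTS.
criteria = model.get_criterion()      # [generator criterion, discriminator criterion]
optimizers = model.get_optimizer()    # [generator optimizer, discriminator optimizer]

for optimizer_idx, optimizer in enumerate(optimizers):
    optimizer.zero_grad()
    # criterion[optimizer_idx] is picked inside train_step(), matching the list order above
    outputs, loss_dict = model.train_step(batch, criteria, optimizer_idx)
    loss_dict["loss"].backward()
    optimizer.step()
```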
@@ -2,6 +2,7 @@ import datetime
 import importlib
 import pickle

+import fsspec
 import numpy as np
 import tensorflow as tf

@@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
         "r": r,
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)


 def load_checkpoint(model, checkpoint_path):
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:
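The change above replaces the builtin `open()` with `fsspec.open()`, which handles local paths and remote URLs through one interface. A minimal sketch with a made-up bucket name; writing to `s3://` additionally assumes the matching fsspec backend (e.g. `s3fs`) is installed:

```python
import pickle

import fsspec

state = {"step": 1000, "epoch": 3}

# A local path and a (hypothetical) remote URL go through the same call.
for path in ("/tmp/checkpoint.pkl", "s3://my-bucket/checkpoints/checkpoint.pkl"):
    with fsspec.open(path, "wb") as f:
        pickle.dump(state, f)
    with fsspec.open(path, "rb") as f:
        restored = pickle.load(f)
```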
@@ -1,6 +1,7 @@
 import datetime
 import pickle

+import fsspec
 import tensorflow as tf


@@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
         "r": r,
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)


 def load_checkpoint(model, checkpoint_path):
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:
@@ -1,3 +1,4 @@
+import fsspec
 import tensorflow as tf


@@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
     print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
     if output_path is not None:
         # same model binary if outputpath is provided
-        with open(output_path, "wb") as f:
+        with fsspec.open(output_path, "wb") as f:
             f.write(tflite_model)
         return None
     return tflite_model
@@ -3,6 +3,7 @@ import os
 import random
 from typing import Any, Dict, List, Tuple, Union

+import fsspec
 import numpy as np
 import torch
 from coqpit import Coqpit

@@ -84,12 +85,12 @@ class SpeakerManager:

    @staticmethod
    def _load_json(json_file_path: str) -> Dict:
-        with open(json_file_path) as f:
+        with fsspec.open(json_file_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict) -> None:
-        with open(json_file_path, "w") as f:
+        with fsspec.open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

    @property

@@ -294,9 +295,10 @@ def _set_file_path(path):
    Intended to band aid the different paths returned in restored and continued training."""
    path_restore = os.path.join(os.path.dirname(path), "speakers.json")
    path_continue = os.path.join(path, "speakers.json")
-    if os.path.exists(path_restore):
+    fs = fsspec.get_mapper(path).fs
+    if fs.exists(path_restore):
        return path_restore
-    if os.path.exists(path_continue):
+    if fs.exists(path_continue):
        return path_continue
    raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")

@@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
        json_file = out_path
    else:
        json_file = _set_file_path(out_path)
-    with open(json_file) as f:
+    with fsspec.open(json_file, "r") as f:
        return json.load(f)


@@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
    """Saves speaker mapping if not yet present."""
    if out_path is not None:
        speakers_json_path = _set_file_path(out_path)
-        with open(speakers_json_path, "w") as f:
+        with fsspec.open(speakers_json_path, "w") as f:
            json.dump(speaker_mapping, f, indent=4)


@@ -358,10 +360,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
        elif c.use_d_vector_file and c.d_vector_file:
            # new speaker manager with external speaker embeddings.
            speaker_manager.set_d_vectors_from_file(c.d_vector_file)
-        elif c.use_d_vector_file and not c.d_vector_file:  # new speaker manager with speaker IDs file.
-            raise "use_d_vector_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
+        elif c.use_d_vector_file and not c.d_vector_file:
+            raise "use_d_vector_file is True, so you need pass a external speaker embedding file."
+        elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
+            # new speaker manager with speaker IDs file.
+            speaker_manager.set_speaker_ids_from_file(c.speakers_file)
        print(
-            " > Training with {} speakers: {}".format(
+            " > Speaker manager is loaded with {} speakers: {}".format(
                speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
            )
        )
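The `fsspec.get_mapper(path).fs` pattern introduced in `_set_file_path` above swaps `os.path`-style checks for filesystem-agnostic ones. A small sketch with a hypothetical run directory:

```python
import fsspec

# get_mapper() resolves the filesystem behind a path or URL; its .fs attribute
# exposes exists()/glob()/rm() analogues of os.path/glob/shutil.
# The bucket and run name below are made up.
run_dir = "s3://my-bucket/runs/vits-ljspeech"
fs = fsspec.get_mapper(run_dir).fs

if fs.exists(run_dir + "/speakers.json"):
    print("found an existing speaker mapping")
```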
@@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
    import tensorflow as tf


-def text_to_seq(text, CONFIG):
+def text_to_seq(text, CONFIG, custom_symbols=None):
    text_cleaner = [CONFIG.text_cleaner]
    # text ot phonemes to sequence vector
    if CONFIG.use_phonemes:

@@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG):
                tp=CONFIG.characters,
                add_blank=CONFIG.add_blank,
                use_espeak_phonemes=CONFIG.use_espeak_phonemes,
+                custom_symbols=custom_symbols,
            ),
            dtype=np.int32,
        )
    else:
        seq = np.asarray(
            text_to_sequence(
-                text,
-                text_cleaner,
-                tp=CONFIG.characters,
-                add_blank=CONFIG.add_blank,
+                text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
            ),
            dtype=np.int32,
        )
@@ -229,13 +227,16 @@ def synthesis(
    """
    # GST processing
    style_mel = None
+    custom_symbols = None
    if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
        if isinstance(style_wav, dict):
            style_mel = style_wav
        else:
            style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
+    if hasattr(model, "make_symbols"):
+        custom_symbols = model.make_symbols(CONFIG)
    # preprocess the given text
-    text_inputs = text_to_seq(text, CONFIG)
+    text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
    # pass tensors to backend
    if backend == "torch":
        if speaker_id is not None:

@@ -274,15 +275,18 @@ def synthesis(
    # convert outputs to numpy
    # plot results
    wav = None
-    if use_griffin_lim:
-        wav = inv_spectrogram(model_outputs, ap, CONFIG)
-        # trim silence
-        if do_trim_silence:
-            wav = trim_silence(wav, ap)
+    if hasattr(model, "END2END") and model.END2END:
+        wav = model_outputs.squeeze(0)
+    else:
+        if use_griffin_lim:
+            wav = inv_spectrogram(model_outputs, ap, CONFIG)
+            # trim silence
+            if do_trim_silence:
+                wav = trim_silence(wav, ap)
    return_dict = {
        "wav": wav,
        "alignments": alignments,
-        "model_outputs": model_outputs,
        "text_inputs": text_inputs,
+        "outputs": outputs,
    }
    return return_dict
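As a usage sketch, `Vits.test_run()` earlier in this diff already drives `synthesis()`; below is a stripped-down call with hypothetical `model`, `config`, and `ap` objects, highlighting the new branch (models that expose `END2END` return a waveform directly, so the Griffin-Lim path is skipped):

```python
# Hypothetical objects; mirrors the call made in Vits.test_run().
return_dict = synthesis(
    model,
    "The quick brown fox jumps over the lazy dog.",
    config,
    False,                  # use_cuda
    ap,
    use_griffin_lim=True,   # only used for spectrogram models; end-to-end models output audio directly
    do_trim_silence=False,
)
wav = return_dict["wav"]
```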
@@ -2,10 +2,9 @@
 # adapted from https://github.com/keithito/tacotron

 import re
-import unicodedata
+from typing import Dict, List

 import gruut
-from packaging import version

 from TTS.tts.utils.text import cleaners
 from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes

@@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}

 _symbols = symbols
 _phonemes = phonemes

 # Regular expression matching text enclosed in curly braces:
 _CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")

@@ -81,7 +81,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
        # Fix a few phonemes
        ph = ph.translate(GRUUT_TRANS_TABLE)

-        print(" > Phonemes: {}".format(ph))
        return ph

    raise ValueError(f" [!] Language {language} is not supported for phonemization.")

@@ -106,13 +105,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):


 def phoneme_to_sequence(
-    text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False
-):
+    text: str,
+    cleaner_names: List[str],
+    language: str,
+    enable_eos_bos: bool = False,
+    custom_symbols: List[str] = None,
+    tp: Dict = None,
+    add_blank: bool = False,
+    use_espeak_phonemes: bool = False,
+) -> List[int]:
+    """Converts a string of phonemes to a sequence of IDs.
+    If `custom_symbols` is provided, it will override the default symbols.
+
+    Args:
+        text (str): string to convert to a sequence
+        cleaner_names (List[str]): names of the cleaner functions to run the text through
+        language (str): text language key for phonemization.
+        enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
+        tp (Dict): dictionary of character parameters to use a custom character set.
+        add_blank (bool): option to add a blank token between each token.
+        use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequence.
+
+    Returns:
+        List[int]: List of integers corresponding to the symbols in the text
+    """
    # pylint: disable=global-statement
    global _phonemes_to_id, _phonemes
-    if tp:
+    if custom_symbols is not None:
+        _phonemes = custom_symbols
+    elif tp:
        _, _phonemes = make_symbols(**tp)
    _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}

    sequence = []
    clean_text = _clean_text(text, cleaner_names)

@@ -127,20 +151,22 @@ def phoneme_to_sequence(
        sequence = pad_with_eos_bos(sequence, tp=tp)
    if add_blank:
        sequence = intersperse(sequence, len(_phonemes))  # add a blank token (new), whose id number is len(_phonemes)

    return sequence


-def sequence_to_phoneme(sequence, tp=None, add_blank=False):
+def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None):
    # pylint: disable=global-statement
    """Converts a sequence of IDs back to a string"""
    global _id_to_phonemes, _phonemes
    if add_blank:
        sequence = list(filter(lambda x: x != len(_phonemes), sequence))
    result = ""
-    if tp:
+    if custom_symbols is not None:
+        _phonemes = custom_symbols
+    elif tp:
        _, _phonemes = make_symbols(**tp)
    _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}

    for symbol_id in sequence:
        if symbol_id in _id_to_phonemes:

@@ -149,27 +175,32 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
    return result.replace("}{", " ")


-def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
+def text_to_sequence(
+    text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
+) -> List[int]:
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
-    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
-    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+    If `custom_symbols` is provided, it will override the default symbols.

    Args:
-        text: string to convert to a sequence
-        cleaner_names: names of the cleaner functions to run the text through
-        tp: dictionary of character parameters to use a custom character set.
+        text (str): string to convert to a sequence
+        cleaner_names (List[str]): names of the cleaner functions to run the text through
+        tp (Dict): dictionary of character parameters to use a custom character set.
+        add_blank (bool): option to add a blank token between each token.

    Returns:
-        List of integers corresponding to the symbols in the text
+        List[int]: List of integers corresponding to the symbols in the text
    """
    # pylint: disable=global-statement
    global _symbol_to_id, _symbols
-    if tp:
+    if custom_symbols is not None:
+        _symbols = custom_symbols
+    elif tp:
        _symbols, _ = make_symbols(**tp)
    _symbol_to_id = {s: i for i, s in enumerate(_symbols)}

    sequence = []

    # Check for curly braces and treat their contents as ARPAbet:
    while text:
        m = _CURLY_RE.match(text)

@@ -185,16 +216,18 @@ def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
    return sequence


-def sequence_to_text(sequence, tp=None, add_blank=False):
+def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
    """Converts a sequence of IDs back to a string"""
    # pylint: disable=global-statement
    global _id_to_symbol, _symbols
    if add_blank:
        sequence = list(filter(lambda x: x != len(_symbols), sequence))

-    if tp:
+    if custom_symbols is not None:
+        _symbols = custom_symbols
+    elif tp:
        _symbols, _ = make_symbols(**tp)
    _id_to_symbol = {i: s for i, s in enumerate(_symbols)}

    result = ""
    for symbol_id in sequence:
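A short usage sketch of the extended API: `custom_symbols` overrides the module-level symbol table, and the same list must be passed again when decoding. The toy symbol list and the `basic_cleaners` choice are illustrative assumptions.

```python
# Toy character set; a real model would pass the list built by Vits.make_symbols(config).
custom_symbols = ["<PAD>", " ", "a", "c", "t"]

seq = text_to_sequence(
    "cat",
    ["basic_cleaners"],              # cleaner names resolved in TTS.tts.utils.text.cleaners
    custom_symbols=custom_symbols,
    add_blank=False,
)
text_back = sequence_to_text(seq, custom_symbols=custom_symbols)
```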
@@ -28,10 +28,10 @@ def make_symbols(
        sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
    )  # this is to keep previous models compatible.
    # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-    _arpabet = ["@" + s for s in _phonemes_sorted]
+    # _arpabet = ["@" + s for s in _phonemes_sorted]
    # Export all symbols:
    _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
-    _symbols += _arpabet
+    # _symbols += _arpabet
    return _symbols, _phonemes
@@ -14,7 +14,10 @@ from TTS.tts.utils.data import StandardScaler


 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
-    """TODO: Merge this with audio.py"""
+    """Some of the audio processing functions using Torch for faster batch processing.
+
+    TODO: Merge this with audio.py
+    """

    def __init__(
        self,

@@ -28,6 +31,8 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
+        do_amp_to_db=False,
+        spec_gain=1.0,
    ):
        super().__init__()
        self.n_fft = n_fft

@@ -39,6 +44,8 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
+        self.do_amp_to_db = do_amp_to_db
+        self.spec_gain = spec_gain
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        if use_mel:

@@ -79,6 +86,8 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
+        if self.do_amp_to_db:
+            S = self._amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):

@@ -87,6 +96,14 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

+    @staticmethod
+    def _amp_to_db(x, spec_gain=1.0):
+        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
+
+    @staticmethod
+    def _db_to_amp(x, spec_gain=1.0):
+        return torch.exp(x) / spec_gain
+

 # pylint: disable=too-many-public-methods
 class AudioProcessor(object):

@@ -97,33 +114,93 @@ class AudioProcessor(object):
    of the class with the model config. They are not meaningful for all the arguments.

    Args:
-        sample_rate (int, optional): target audio sampling rate. Defaults to None.
-        resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
-        num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
-        log_func (int, optional): log exponent used for converting spectrogram aplitude to DB.
-        min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None.
-        frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
-        frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None.
-        hop_length (int, optional): number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
-        win_length (int, optional): STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
-        ref_level_db (int, optional): reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
-        fft_size (int, optional): FFT window size for STFT. Defaults to 1024.
-        power (int, optional): Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
-        preemphasis (float, optional): Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
-        signal_norm (bool, optional): enable/disable signal normalization. Defaults to None.
-        symmetric_norm (bool, optional): enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
-        max_norm (float, optional): ```k``` defining the normalization range. Defaults to None.
-        mel_fmin (int, optional): minimum filter frequency for computing melspectrograms. Defaults to None.
-        mel_fmax (int, optional): maximum filter frequency for computing melspectrograms.. Defaults to None.
-        spec_gain (int, optional): gain applied when converting amplitude to DB. Defaults to 20.
-        stft_pad_mode (str, optional): Padding mode for STFT. Defaults to 'reflect'.
-        clip_norm (bool, optional): enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
-        griffin_lim_iters (int, optional): Number of GriffinLim iterations. Defaults to None.
-        do_trim_silence (bool, optional): enable/disable silence trimming when loading the audio signal. Defaults to False.
-        trim_db (int, optional): DB threshold used for silence trimming. Defaults to 60.
-        do_sound_norm (bool, optional): enable/disable signal normalization. Defaults to False.
-        stats_path (str, optional): Path to the computed stats file. Defaults to None.
-        verbose (bool, optional): enable/disable logging. Defaults to True.
+        sample_rate (int, optional):
+            target audio sampling rate. Defaults to None.
+
+        resample (bool, optional):
+            enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
+
+        num_mels (int, optional):
+            number of melspectrogram dimensions. Defaults to None.
+
+        log_func (int, optional):
+            log exponent used for converting spectrogram amplitude to DB.
+
+        min_level_db (int, optional):
+            minimum db threshold for the computed melspectrograms. Defaults to None.
+
+        frame_shift_ms (int, optional):
+            milliseconds of frames between STFT columns. Defaults to None.
+
+        frame_length_ms (int, optional):
+            milliseconds of STFT window length. Defaults to None.
+
+        hop_length (int, optional):
+            number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
+
+        win_length (int, optional):
+            STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
+
+        ref_level_db (int, optional):
+            reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
+
+        fft_size (int, optional):
+            FFT window size for STFT. Defaults to 1024.
+
+        power (int, optional):
+            Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
+
+        preemphasis (float, optional):
+            Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
+
+        signal_norm (bool, optional):
+            enable/disable signal normalization. Defaults to None.
+
+        symmetric_norm (bool, optional):
+            enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
+
+        max_norm (float, optional):
+            ```k``` defining the normalization range. Defaults to None.
+
+        mel_fmin (int, optional):
+            minimum filter frequency for computing melspectrograms. Defaults to None.
+
+        mel_fmax (int, optional):
+            maximum filter frequency for computing melspectrograms. Defaults to None.
+
+        spec_gain (int, optional):
+            gain applied when converting amplitude to DB. Defaults to 20.
+
+        stft_pad_mode (str, optional):
+            Padding mode for STFT. Defaults to 'reflect'.
+
+        clip_norm (bool, optional):
+            enable/disable clipping the out of range values in the normalized audio signal. Defaults to True.
+
+        griffin_lim_iters (int, optional):
+            Number of GriffinLim iterations. Defaults to None.
+
+        do_trim_silence (bool, optional):
+            enable/disable silence trimming when loading the audio signal. Defaults to False.
+
+        trim_db (int, optional):
+            DB threshold used for silence trimming. Defaults to 60.
+
+        do_sound_norm (bool, optional):
+            enable/disable signal normalization. Defaults to False.
+
+        do_amp_to_db_linear (bool, optional):
+            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
+
+        do_amp_to_db_mel (bool, optional):
+            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
+
+        stats_path (str, optional):
+            Path to the computed stats file. Defaults to None.
+
+        verbose (bool, optional):
+            enable/disable logging. Defaults to True.
+
    """

    def __init__(

@@ -153,6 +230,8 @@ class AudioProcessor(object):
        do_trim_silence=False,
        trim_db=60,
        do_sound_norm=False,
+        do_amp_to_db_linear=True,
+        do_amp_to_db_mel=True,
        stats_path=None,
        verbose=True,
        **_,

@@ -181,6 +260,8 @@ class AudioProcessor(object):
        self.do_trim_silence = do_trim_silence
        self.trim_db = trim_db
        self.do_sound_norm = do_sound_norm
+        self.do_amp_to_db_linear = do_amp_to_db_linear
+        self.do_amp_to_db_mel = do_amp_to_db_mel
        self.stats_path = stats_path
        # setup exp_func for db to amp conversion
        if log_func == "np.log":

@@ -381,7 +462,6 @@ class AudioProcessor(object):
        Returns:
            np.ndarray: Decibels spectrogram.
        """
-
        return self.spec_gain * _log(np.maximum(1e-5, x), self.base)

    # pylint: disable=no-self-use

@@ -448,7 +528,10 @@ class AudioProcessor(object):
            D = self._stft(self.apply_preemphasis(y))
        else:
            D = self._stft(y)
-        S = self._amp_to_db(np.abs(D))
+        if self.do_amp_to_db_linear:
+            S = self._amp_to_db(np.abs(D))
+        else:
+            S = np.abs(D)
        return self.normalize(S).astype(np.float32)

    def melspectrogram(self, y: np.ndarray) -> np.ndarray:

@@ -457,7 +540,10 @@ class AudioProcessor(object):
            D = self._stft(self.apply_preemphasis(y))
        else:
            D = self._stft(y)
-        S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
+        if self.do_amp_to_db_mel:
+            S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
+        else:
+            S = self._linear_to_mel(np.abs(D))
        return self.normalize(S).astype(np.float32)

    def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
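A minimal sketch of the new `TorchSTFT` flags (`use_mel`, `do_amp_to_db`, `spec_gain`). Only the parameters shown in the diff are certain; the `hop_length`, `win_length`, and `sample_rate` keyword names and the input shape are assumptions based on the surrounding code, and all values are arbitrary.

```python
import torch

# Assumed constructor keywords beyond those visible in the diff: hop_length, win_length, sample_rate.
stft = TorchSTFT(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    sample_rate=22050,
    n_mels=80,
    use_mel=True,
    do_amp_to_db=True,   # new flag: log-compress the (mel) spectrogram inside forward()
    spec_gain=1.0,
)

wav = torch.randn(4, 22050)   # assumed shape: [batch, samples] of fake audio
mel_db = stft(wav)            # dB-scaled mel spectrogram
```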
@@ -1,15 +1,14 @@
 # -*- coding: utf-8 -*-
 import datetime
-import glob
 import importlib
 import os
 import re
-import shutil
 import subprocess
 import sys
 from pathlib import Path
 from typing import Dict

+import fsspec
 import torch


@@ -58,23 +57,22 @@ def get_commit_hash():
    return commit


-def create_experiment_folder(root_path, model_name):
-    """Create a folder with the current date and time"""
+def get_experiment_folder_path(root_path, model_name):
+    """Get an experiment folder path with the current date and time"""
    date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    commit_hash = get_commit_hash()
    output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
-    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))
    return output_folder


 def remove_experiment_folder(experiment_path):
    """Check folder if there is a checkpoint, otherwise remove the folder"""
+    fs = fsspec.get_mapper(experiment_path).fs
-    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
+    checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
    if not checkpoint_files:
-        if os.path.exists(experiment_path):
-            shutil.rmtree(experiment_path, ignore_errors=True)
+        if fs.exists(experiment_path):
+            fs.rm(experiment_path, recursive=True)
            print(" ! Run is removed from {}".format(experiment_path))
    else:
        print(" ! Run is kept in {}".format(experiment_path))
@@ -1,9 +1,11 @@
 import datetime
-import glob
+import json
 import os
 import pickle as pickle_tts
-from shutil import copyfile
+import shutil
+from typing import Any

+import fsspec
 import torch
 from coqpit import Coqpit


@@ -24,7 +26,7 @@ class AttrDict(dict):
        self.__dict__ = self


-def copy_model_files(config, out_path, new_fields):
+def copy_model_files(config: Coqpit, out_path, new_fields):
    """Copy config.json and other model files to training folder and add
    new fields.

@@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
    copy_config_path = os.path.join(out_path, "config.json")
    # add extra information fields
    config.update(new_fields, allow_new=True)
-    config.save_json(copy_config_path)
+    # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
+    with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
+        json.dump(config.to_dict(), f, indent=4)
+
    # copy model stats file if available
    if config.audio.stats_path is not None:
        copy_stats_path = os.path.join(out_path, "scale_stats.npy")
-        if not os.path.exists(copy_stats_path):
-            copyfile(
-                config.audio.stats_path,
-                copy_stats_path,
-            )
+        filesystem = fsspec.get_mapper(copy_stats_path).fs
+        if not filesystem.exists(copy_stats_path):
+            with fsspec.open(config.audio.stats_path, "rb") as source_file:
+                with fsspec.open(copy_stats_path, "wb") as target_file:
+                    shutil.copyfileobj(source_file, target_file)
+
+
+def load_fsspec(path: str, **kwargs) -> Any:
+    """Like torch.load but can load from other locations (e.g. s3:// , gs://).
+
+    Args:
+        path: Any path or url supported by fsspec.
+        **kwargs: Keyword arguments forwarded to torch.load.
+
+    Returns:
+        Object stored in path.
+    """
+    with fsspec.open(path, "rb") as f:
+        return torch.load(f, **kwargs)


 def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    try:
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
    except ModuleNotFoundError:
        pickle_tts.Unpickler = RenamingUnpickler
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()

@@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    return model, state


+def save_fsspec(state: Any, path: str, **kwargs):
+    """Like torch.save but can save to other locations (e.g. s3:// , gs://).
+
+    Args:
+        state: State object to save
+        path: Any path or url supported by fsspec.
+        **kwargs: Keyword arguments forwarded to torch.save.
+    """
+    with fsspec.open(path, "wb") as f:
+        torch.save(state, f, **kwargs)
+
+
 def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
    if hasattr(model, "module"):
        model_state = model.module.state_dict()

@@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    state.update(kwargs)
-    torch.save(state, output_path)
+    save_fsspec(state, output_path)


 def save_checkpoint(

@@ -147,18 +178,16 @@ def save_best_model(
            model_loss=current_loss,
            **kwargs,
        )
+        fs = fsspec.get_mapper(out_path).fs
        # only delete previous if current is saved successfully
        if not keep_all_best or (current_step < keep_after):
-            model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
+            model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
            for model_name in model_names:
-                if os.path.basename(model_name) == best_model_name:
-                    continue
-                os.remove(model_name)
-            # create symlink to best model for convinience
-            link_name = "best_model.pth.tar"
-            link_path = os.path.join(out_path, link_name)
-            if os.path.islink(link_path) or os.path.isfile(link_path):
-                os.remove(link_path)
-            os.symlink(best_model_name, os.path.join(out_path, link_name))
+                if os.path.basename(model_name) != best_model_name:
+                    fs.rm(model_name)
+            # create a shortcut which always points to the currently best model
+            shortcut_name = "best_model.pth.tar"
+            shortcut_path = os.path.join(out_path, shortcut_name)
+            fs.copy(checkpoint_path, shortcut_path)
        best_loss = current_loss
        return best_loss
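Putting the two new helpers together, a small round-trip sketch; the remote path is hypothetical and assumes the matching fsspec backend (e.g. `s3fs`) is installed, and a plain local path works the same way.

```python
import torch

# Hypothetical checkpoint location on object storage.
checkpoint_path = "s3://my-bucket/runs/vits/best_model.pth.tar"

state = {"model": {"dummy.weight": torch.zeros(1)}, "step": 1000}
save_fsspec(state, checkpoint_path)

restored = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
```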
@@ -1,2 +1,24 @@
 from TTS.utils.logging.console_logger import ConsoleLogger
 from TTS.utils.logging.tensorboard_logger import TensorboardLogger
+from TTS.utils.logging.wandb_logger import WandbLogger
+
+
+def init_logger(config):
+    if config.dashboard_logger == "tensorboard":
+        dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)
+
+    elif config.dashboard_logger == "wandb":
+        project_name = config.model
+        if config.project_name:
+            project_name = config.project_name
+
+        dashboard_logger = WandbLogger(
+            project=project_name,
+            name=config.run_name,
+            config=config,
+            entity=config.wandb_entity,
+        )
+
+    dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+
+    return dashboard_logger
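For reference, a rough sketch of the config fields `init_logger()` reads; the values below are placeholders, not defaults shipped with 🐸TTS.

```python
# Placeholder values; init_logger() only reads these attributes off the training config.
config.dashboard_logger = "wandb"          # or "tensorboard"
config.output_log_path = "output/logs"     # used by TensorboardLogger
config.model = "vits"
config.project_name = "my-tts-project"     # optional; falls back to config.model
config.run_name = "vits-ljspeech-run1"
config.wandb_entity = None

dashboard_logger = init_logger(config)
```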
@@ -38,7 +38,7 @@ class ConsoleLogger:
    def print_train_start(self):
        print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")

-    def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict):
+    def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_dict):
        indent = " | > "
        print()
        log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(

@@ -50,13 +50,6 @@ class ConsoleLogger:
                log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
            else:
                log_text += "{}{}: {:.5f} \n".format(indent, key, value)
-        for idx, (key, value) in enumerate(log_dict.items()):
-            if isinstance(value, list):
-                log_text += f"{indent}{key}: {value[0]:.{value[1]}f}"
-            else:
-                log_text += f"{indent}{key}: {value}"
-            if idx < len(log_dict) - 1:
-                log_text += "\n"
        print(log_text, flush=True)

    # pylint: disable=unused-argument
@@ -7,10 +7,8 @@ class TensorboardLogger(object):
    def __init__(self, log_dir, model_name):
        self.model_name = model_name
        self.writer = SummaryWriter(log_dir)
-        self.train_stats = {}
-        self.eval_stats = {}

-    def tb_model_weights(self, model, step):
+    def model_weights(self, model, step):
        layer_num = 1
        for name, param in model.named_parameters():
            if param.numel() == 1:

@@ -41,32 +39,41 @@ class TensorboardLogger(object):
        except RuntimeError:
            traceback.print_exc()

-    def tb_train_step_stats(self, step, stats):
+    def train_step_stats(self, step, stats):
        self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)

-    def tb_train_epoch_stats(self, step, stats):
+    def train_epoch_stats(self, step, stats):
        self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step)

-    def tb_train_figures(self, step, figures):
+    def train_figures(self, step, figures):
        self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step)

-    def tb_train_audios(self, step, audios, sample_rate):
+    def train_audios(self, step, audios, sample_rate):
        self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate)

-    def tb_eval_stats(self, step, stats):
+    def eval_stats(self, step, stats):
        self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step)

-    def tb_eval_figures(self, step, figures):
+    def eval_figures(self, step, figures):
        self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step)

-    def tb_eval_audios(self, step, audios, sample_rate):
+    def eval_audios(self, step, audios, sample_rate):
        self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate)

-    def tb_test_audios(self, step, audios, sample_rate):
+    def test_audios(self, step, audios, sample_rate):
        self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate)

-    def tb_test_figures(self, step, figures):
+    def test_figures(self, step, figures):
        self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step)

-    def tb_add_text(self, title, text, step):
+    def add_text(self, title, text, step):
        self.writer.add_text(title, text, step)

+    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):  # pylint: disable=W0613, R0201
+        yield
+
+    def flush(self):
+        self.writer.flush()
+
+    def finish(self):
+        self.writer.close()
@@ -0,0 +1,111 @@
# pylint: disable=W0613

import traceback
from pathlib import Path

try:
    import wandb
    from wandb import finish, init  # pylint: disable=W0611
except ImportError:
    wandb = None


class WandbLogger:
    def __init__(self, **kwargs):

        if not wandb:
            raise Exception("install wandb using `pip install wandb` to use WandbLogger")

        self.run = None
        self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
        self.model_name = self.run.config.model
        self.log_dict = {}

    def model_weights(self, model):
        layer_num = 1
        for name, param in model.named_parameters():
            if param.numel() == 1:
                self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()})
            else:
                self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()})
                self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
                self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
                self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
                self.log_dict["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param)
                self.log_dict["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad)
            layer_num += 1

    def dict_to_scalar(self, scope_name, stats):
        for key, value in stats.items():
            self.log_dict["{}/{}".format(scope_name, key)] = value

    def dict_to_figure(self, scope_name, figures):
        for key, value in figures.items():
            self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value)

    def dict_to_audios(self, scope_name, audios, sample_rate):
        for key, value in audios.items():
            if value.dtype == "float16":
                value = value.astype("float32")
            try:
                self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
            except RuntimeError:
                traceback.print_exc()

    def log(self, log_dict, prefix="", flush=False):
        for key, value in log_dict.items():
            self.log_dict[prefix + key] = value
        if flush:  # for cases where you don't want to accumulate data
            self.flush()

    def train_step_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats)

    def train_epoch_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats)

    def train_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_TrainFigures", figures)

    def train_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate)

    def eval_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_EvalStats", stats)

    def eval_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_EvalFigures", figures)

    def eval_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate)

    def test_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate)

    def test_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_TestFigures", figures)

    def add_text(self, title, text, step):
        pass

    def flush(self):
        if self.run:
            wandb.log(self.log_dict)
        self.log_dict = {}

    def finish(self):
        if self.run:
            self.run.finish()

    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
        if not self.run:
            return
        name = "_".join([self.run.id, name])
        artifact = wandb.Artifact(name, type=artifact_type)
        data_path = Path(file_or_dir)
        if data_path.is_dir():
            artifact.add_dir(str(data_path))
        elif data_path.is_file():
            artifact.add_file(str(data_path))

        self.run.log_artifact(artifact, aliases=aliases)
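A minimal usage sketch of the logger above. The project name, config values, and import path are illustrative assumptions, not part of this changeset; the only requirement from the code is that the wandb run config carries a `model` field.

```python
# Sketch only: project/config values are placeholders; import path is assumed.
from TTS.utils.logging.wandb_logger import WandbLogger

logger = WandbLogger(project="tts-dev", config={"model": "glow_tts"})
logger.train_step_stats(step=100, stats={"loss": 0.42})  # accumulated in log_dict
logger.flush()   # pushes the accumulated log_dict to the wandb run
logger.finish()  # closes the run
```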
@@ -64,6 +64,7 @@ class ModelManager(object):
     def list_models(self):
         print(" Name format: type/language/dataset/model")
         models_name_list = []
+        model_count = 1
         for model_type in self.models_dict:
             for lang in self.models_dict[model_type]:
                 for dataset in self.models_dict[model_type][lang]:

@@ -71,10 +72,11 @@ class ModelManager(object):
                     model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
                     output_path = os.path.join(self.output_prefix, model_full_name)
                     if os.path.exists(output_path):
-                        print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
+                        print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
                     else:
-                        print(f" >: {model_type}/{lang}/{dataset}/{model}")
+                        print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
                     models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
+                    model_count += 1
         return models_name_list

     def download_model(self, model_name):
@@ -12,7 +12,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
 from TTS.tts.utils.synthesis import synthesis, trim_silence
-from TTS.tts.utils.text import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input

@@ -103,6 +102,34 @@ class Synthesizer(object):
             self.num_speakers = self.speaker_manager.num_speakers
             self.d_vector_dim = self.speaker_manager.d_vector_dim

+    def _set_tts_speaker_file(self):
+        """Set the TTS speaker file used by a multi-speaker model."""
+        # setup if multi-speaker settings are in the global model config
+        if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
+            if self.tts_config.use_d_vector_file:
+                self.tts_speakers_file = (
+                    self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
+                )
+                self.tts_config["d_vector_file"] = self.tts_speakers_file
+            else:
+                self.tts_speakers_file = (
+                    self.tts_speakers_file if self.tts_speakers_file else self.tts_config["speakers_file"]
+                )
+
+        # setup if multi-speaker settings are in the model args config
+        if (
+            self.tts_speakers_file is None
+            and hasattr(self.tts_config, "model_args")
+            and hasattr(self.tts_config.model_args, "use_speaker_embedding")
+            and self.tts_config.model_args.use_speaker_embedding
+        ):
+            _args = self.tts_config.model_args
+            if _args.use_d_vector_file:
+                self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["d_vector_file"]
+                _args["d_vector_file"] = self.tts_speakers_file
+            else:
+                self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["speakers_file"]
+
     def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
         """Load the TTS model.

@@ -113,29 +140,15 @@ class Synthesizer(object):
         """
         # pylint: disable=global-statement
-        global symbols, phonemes
         self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
-
-        if self.tts_config.has("characters") and self.tts_config.characters:
-            symbols, phonemes = make_symbols(**self.tts_config.characters)
-
-        if self.use_phonemes:
-            self.input_size = len(phonemes)
-        else:
-            self.input_size = len(symbols)
-
-        if self.tts_config.use_speaker_embedding is True:
-            self.tts_speakers_file = (
-                self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
-            )
-            self.tts_config["d_vector_file"] = self.tts_speakers_file
-
         self.tts_model = setup_tts_model(config=self.tts_config)
         self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()
+        self._set_tts_speaker_file()

     def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         """Load the vocoder model.

@@ -187,15 +200,22 @@ class Synthesizer(object):
         """
         start_time = time.time()
         wavs = []
-        speaker_embedding = None
         sens = self.split_into_sentences(text)
         print(" > Text split into sentences.")
         print(sens)

+        # handle multi-speaker
+        speaker_embedding = None
+        speaker_id = None
         if self.tts_speakers_file:
-            # get the speaker embedding from the saved d_vectors.
             if speaker_idx and isinstance(speaker_idx, str):
-                speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
+                if self.tts_config.use_d_vector_file:
+                    # get the speaker embedding from the saved d_vectors.
+                    speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
+                else:
+                    # get speaker idx from the speaker name
+                    speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
+
             elif not speaker_idx and not speaker_wav:
                 raise ValueError(
                     " [!] Looks like you are using a multi-speaker model. "

@@ -224,14 +244,14 @@ class Synthesizer(object):
                     CONFIG=self.tts_config,
                     use_cuda=self.use_cuda,
                     ap=self.ap,
-                    speaker_id=None,
+                    speaker_id=speaker_id,
                     style_wav=style_wav,
                     enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
                     use_griffin_lim=use_gl,
                     d_vector=speaker_embedding,
                 )
                 waveform = outputs["wav"]
-                mel_postnet_spec = outputs["model_outputs"]
+                mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy()
                 if not use_gl:
                     # denormalize tts output based on tts audio config
                     mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
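For reference, this is roughly how a caller exercises the multi-speaker branch above. The checkpoint paths and speaker name are placeholders, and the constructor keywords are assumptions inferred from the loader methods shown in this diff, not a verified signature.

```python
# Sketch only: paths and the speaker name are placeholders.
from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="path/to/checkpoint.pth.tar",   # placeholder path
    tts_config_path="path/to/config.json",         # placeholder path
    use_cuda=False,
)
# With use_d_vector_file=True the name selects a stored d-vector;
# otherwise it is mapped to an id through speaker_manager.speaker_ids.
wav = synthesizer.tts("Hello world.", speaker_idx="ljspeech-1")
```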
@@ -1,5 +1,5 @@
 import importlib
-from typing import Dict
+from typing import Dict, List

 import torch

@@ -48,7 +48,7 @@ def get_scheduler(
 def get_optimizer(
-    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
+    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
 ) -> torch.optim.Optimizer:
     """Find, initialize and return an optimizer.

@@ -66,4 +66,6 @@ def get_optimizer(
         optimizer = getattr(module, "RAdam")
     else:
         optimizer = getattr(torch.optim, optimizer_name)
-    return optimizer(model.parameters(), lr=lr, **optimizer_params)
+    if model is not None:
+        parameters = model.parameters()
+    return optimizer(parameters, lr=lr, **optimizer_params)
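A short sketch of the two calling conventions this change enables; the optimizer names and hyperparameters below are illustrative.

```python
# Pass either a model (its .parameters() are used) or an explicit parameter list.
import torch

from TTS.utils.trainer_utils import get_optimizer

net = torch.nn.Linear(8, 8)
opt_a = get_optimizer("Adam", {"betas": (0.9, 0.99)}, lr=1e-3, model=net)
# e.g. GAN-style training, where generator and discriminator params are optimized separately:
opt_b = get_optimizer("AdamW", {}, lr=2e-4, parameters=list(net.parameters()))
```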
@@ -24,8 +24,10 @@ def setup_model(config: Coqpit):
     elif config.model.lower() == "wavegrad":
         MyModel = getattr(MyModel, "Wavegrad")
     else:
-        MyModel = getattr(MyModel, to_camel(config.model))
-        raise ValueError(f"Model {config.model} not exist!")
+        try:
+            MyModel = getattr(MyModel, to_camel(config.model))
+        except ModuleNotFoundError as e:
+            raise ValueError(f"Model {config.model} not exist!") from e
     model = MyModel(config)
     return model
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss

@@ -222,7 +223,7 @@ class GAN(BaseVocoder):
             checkpoint_path (str): Checkpoint file path.
             eval (bool, optional): If true, load the model for inference. Defaults to False.
         """
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # band-aid for older than v0.0.15 GAN models
         if "model_disc" in state:
             self.model_g.load_checkpoint(config, checkpoint_path, eval)
@@ -33,10 +33,10 @@ class DiscriminatorP(torch.nn.Module):
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
         self.convs = nn.ModuleList(
             [
-                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                 norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
             ]
         )
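For context, the helper being called here is HiFiGAN's usual "same"-padding computation, so passing the real `kernel_size` instead of the hard-coded 5 keeps the padded axis length-aligned for any kernel size. A sketch of that helper as it is commonly implemented (shown for reference, not copied from this hunk):

```python
def get_padding(kernel_size, dilation=1):
    # "same" padding along the convolved axis of a (possibly dilated) kernel
    return int((kernel_size * dilation - dilation) / 2)
```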
@@ -81,15 +81,15 @@ class MultiPeriodDiscriminator(torch.nn.Module):
     Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
     """

-    def __init__(self):
+    def __init__(self, use_spectral_norm=False):
         super().__init__()
         self.discriminators = nn.ModuleList(
             [
-                DiscriminatorP(2),
-                DiscriminatorP(3),
-                DiscriminatorP(5),
-                DiscriminatorP(7),
-                DiscriminatorP(11),
+                DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
             ]
         )

@@ -99,7 +99,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
             x (Tensor): input waveform.

         Returns:
             [List[Tensor]]: list of scores from each discriminator.
             [List[List[Tensor]]]: list of lists of features from each convolutional layer of each discriminator.

         Shapes:
@@ -5,6 +5,8 @@ import torch.nn.functional as F
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn.utils import remove_weight_norm, weight_norm

+from TTS.utils.io import load_fsspec
+
 LRELU_SLOPE = 0.1

@@ -168,6 +170,10 @@ class HifiganGenerator(torch.nn.Module):
         upsample_initial_channel,
         upsample_factors,
         inference_padding=5,
+        cond_channels=0,
+        conv_pre_weight_norm=True,
+        conv_post_weight_norm=True,
+        conv_post_bias=True,
     ):
         r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)

@@ -216,12 +222,21 @@ class HifiganGenerator(torch.nn.Module):
         for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
             self.resblocks.append(resblock(ch, k, d))
         # post convolution layer
-        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
+        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
+        if cond_channels > 0:
+            self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
+
+        if not conv_pre_weight_norm:
+            remove_weight_norm(self.conv_pre)
+
+        if not conv_post_weight_norm:
+            remove_weight_norm(self.conv_post)

-    def forward(self, x):
+    def forward(self, x, g=None):
         """
         Args:
-            x (Tensor): conditioning input tensor.
+            x (Tensor): feature input tensor.
+            g (Tensor): global conditioning input tensor.

         Returns:
             Tensor: output waveform.

@@ -231,6 +246,8 @@ class HifiganGenerator(torch.nn.Module):
             Tensor: [B, 1, T]
         """
         o = self.conv_pre(x)
+        if hasattr(self, "cond_layer"):
+            o = o + self.cond_layer(g)
         for i in range(self.num_upsamples):
             o = F.leaky_relu(o, LRELU_SLOPE)
             o = self.ups[i](o)
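A small sketch of how the new global-conditioning path is used. The shapes follow the docstring above; the channel sizes are placeholders and the generator construction is omitted because its full argument list is not reproduced here.

```python
# Sketch: x is the feature input [B, C, T]; g is a global conditioning vector
# broadcast over time, e.g. a speaker embedding reshaped to [B, cond_channels, 1].
import torch

x = torch.randn(1, 80, 50)   # e.g. a mel spectrogram
g = torch.randn(1, 256, 1)   # only meaningful if the generator was built with cond_channels=256
# o = generator(x, g)        # generator: a HifiganGenerator instance (construction omitted)
```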
@@ -275,7 +292,7 @@ class HifiganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -2,6 +2,7 @@ import torch
 from torch import nn
 from torch.nn.utils import weight_norm

+from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.melgan import ResidualStack

@@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -3,6 +3,7 @@ import math
 import numpy as np
 import torch

+from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
 from TTS.vocoder.layers.upsample import ConvUpsample

@@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -1,3 +1,5 @@
+from typing import List
+
 import numpy as np
 import torch
 import torch.nn.functional as F

@@ -10,18 +12,35 @@ LRELU_SLOPE = 0.1
 class UnivnetGenerator(torch.nn.Module):
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        cond_channels,
-        upsample_factors,
-        lvc_layers_each_block,
-        lvc_kernel_size,
-        kpnet_hidden_channels,
-        kpnet_conv_size,
-        dropout,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int,
+        cond_channels: int,
+        upsample_factors: List[int],
+        lvc_layers_each_block: int,
+        lvc_kernel_size: int,
+        kpnet_hidden_channels: int,
+        kpnet_conv_size: int,
+        dropout: float,
         use_weight_norm=True,
     ):
+        """Univnet Generator network.
+
+        Paper: https://arxiv.org/pdf/2106.07889.pdf
+
+        Args:
+            in_channels (int): Number of input tensor channels.
+            out_channels (int): Number of channels of the output tensor.
+            hidden_channels (int): Number of hidden network channels.
+            cond_channels (int): Number of channels of the conditioning tensors.
+            upsample_factors (List[int]): List of upsample factors for the upsampling layers.
+            lvc_layers_each_block (int): Number of LVC layers in each block.
+            lvc_kernel_size (int): Kernel size of the LVC layers.
+            kpnet_hidden_channels (int): Number of hidden channels in the kernel prediction network.
+            kpnet_conv_size (int): Number of convolution channels in the kernel prediction network.
+            dropout (float): Dropout rate.
+            use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
+        """
         super().__init__()
         self.in_channels = in_channels
@@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler

 from TTS.model import BaseModel
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.datasets import WaveGradDataset
 from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock

@@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler

 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss
 from TTS.vocoder.models.base_vocoder import BaseVocoder

@@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -1,6 +1,7 @@
 import datetime
 import pickle

+import fsspec
 import tensorflow as tf

@@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
         "date": datetime.date.today().strftime("%B %d, %Y"),
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)


 def load_checkpoint(model, checkpoint_path):
     """Load TF Vocoder model"""
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:
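The point of switching from plain `open()` to `fsspec.open()` is that checkpoint paths are no longer restricted to the local filesystem. A hedged sketch (the bucket path is illustrative, and remote backends need the matching fsspec driver such as s3fs or gcsfs installed):

```python
import pickle

import fsspec

# The same call works for local paths and for remote URLs such as s3:// or gs://.
with fsspec.open("s3://my-bucket/checkpoints/checkpoint_10000.pkl", "wb") as f:
    pickle.dump({"step": 10000}, f)
```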
@@ -1,3 +1,4 @@
+import fsspec
 import tensorflow as tf

@@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=True):
     print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
     if output_path is not None:
         # save the model binary if output_path is provided
-        with open(output_path, "wb") as f:
+        with fsspec.open(output_path, "wb") as f:
             f.write(tflite_model)
         return None
     return tflite_model
@@ -1,8 +1,11 @@
+from typing import Dict
+
 import numpy as np
 import torch
 from matplotlib import pyplot as plt

 from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio import AudioProcessor


 def interpolate_vocoder_input(scale_factor, spec):

@@ -26,12 +29,24 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec


-def plot_results(y_hat, y, ap, name_prefix):
-    """Plot vocoder model results"""
+def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
+    """Plot the predicted and the real waveform and their spectrograms.
+
+    Args:
+        y_hat (torch.tensor): Predicted waveform.
+        y (torch.tensor): Real waveform.
+        ap (AudioProcessor): Audio processor used to process the waveform.
+        name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
+
+    Returns:
+        Dict: output figures keyed by the name of the figures.
+    """
+    if name_prefix is None:
+        name_prefix = ""

     # select an instance from batch
-    y_hat = y_hat[0].squeeze(0).detach().cpu().numpy()
-    y = y[0].squeeze(0).detach().cpu().numpy()
+    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
+    y = y[0].squeeze().detach().cpu().numpy()

     spec_fake = ap.melspectrogram(y_hat).T
     spec_real = ap.melspectrogram(y).T
@@ -2,4 +2,5 @@ furo
 myst-parser == 0.15.1
 sphinx == 4.0.2
 sphinx_inline_tabs
 sphinx_copybutton
+linkify-it-py

@@ -68,6 +68,8 @@ extensions = [
     "sphinx_inline_tabs",
 ]

+myst_enable_extensions = ['linkify',]
+
 # 'sphinxcontrib.katex',
 # 'sphinx.ext.autosectionlabel',

@@ -44,6 +44,7 @@
     :caption: `tts` Models

     models/glow_tts.md
+    models/vits.md

 .. toctree::
     :maxdepth: 2
@@ -0,0 +1,22 @@
# Glow TTS

Glow TTS is a normalizing flow model for text-to-speech. It is built on the generic Glow model previously
used in computer vision and vocoder models. It uses "monotonic alignment search" (MAS) to find the text-to-speech alignment
and uses the output to train a separate duration predictor network for faster inference run-time.

## Important resources & papers
- GlowTTS: https://arxiv.org/abs/2005.11129
- Glow (Generative Flow with invertible 1x1 Convolutions): https://arxiv.org/abs/1807.03039
- Normalizing Flows: https://blog.evjang.com/2018/01/nf1.html

## GlowTTS Config
```{eval-rst}
.. autoclass:: TTS.tts.configs.glow_tts_config.GlowTTSConfig
    :members:
```

## GlowTTS Model
```{eval-rst}
.. autoclass:: TTS.tts.models.glow_tts.GlowTTS
    :members:
```
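A minimal construction sketch using the two classes documented above. The config values are illustrative, not a recommended recipe, and passing the config straight to the model mirrors how `setup_model` builds models elsewhere in this changeset.

```python
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS

# Illustrative training-level options; model/architecture options keep their defaults.
config = GlowTTSConfig(batch_size=32, eval_batch_size=16, run_eval=True)
model = GlowTTS(config)
print(model)
```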
@@ -0,0 +1,33 @@
# VITS

VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech) is an
end-to-end (encoder -> vocoder together) TTS model that takes advantage of SOTA DL techniques like GANs, VAEs, and
Normalizing Flows. It does not require external alignment annotations and learns the text-to-audio alignment
using MAS, as explained in the paper. The model architecture is a combination of the GlowTTS encoder and the HiFiGAN vocoder.
It is a feed-forward model with a real-time factor of x67.12 on a GPU.

## Important resources & papers
- VITS: https://arxiv.org/pdf/2106.06103.pdf
- Neural Spline Flows: https://arxiv.org/abs/1906.04032
- Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf
- Generative Adversarial Networks: https://arxiv.org/abs/1406.2661
- HiFiGAN: https://arxiv.org/abs/2010.05646
- Normalizing Flows: https://blog.evjang.com/2018/01/nf1.html

## VitsConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.vits_config.VitsConfig
    :members:
```

## VitsArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.vits.VitsArgs
    :members:
```

## Vits Model
```{eval-rst}
.. autoclass:: TTS.tts.models.vits.Vits
    :members:
```
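Analogously, a tiny sketch with the VITS config class documented above; the values are taken from the LJSpeech recipe that appears later in this changeset, and the `model_args` field name is an assumption based on the `VitsArgs` class listed here.

```python
from TTS.tts.configs.vits_config import VitsConfig

# Illustrative values borrowed from the LJSpeech recipe below.
config = VitsConfig(batch_size=48, mixed_precision=True, max_seq_len=5000)
print(config.model_args)  # assumed: holds the VitsArgs architecture-level options
```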
@@ -85,6 +85,7 @@ We still support running training from CLI like in the old days. The same training
 ```json
 {
+    "run_name": "my_run",
     "model": "glow_tts",
     "batch_size": 32,
     "eval_batch_size": 16,
@@ -25,9 +25,7 @@
     "import umap\n",
     "\n",
     "from TTS.speaker_encoder.model import SpeakerEncoder\n",
-    "from TTS.utils.audio import AudioProcessor
-\n",
+    "from TTS.utils.audio import AudioProcessor\n",
     "from TTS.tts.utils.generic_utils import load_config\n",
     "\n",
     "from bokeh.io import output_notebook, show\n",
@@ -1,12 +1,12 @@
 import os

-from TTS.tts.configs import AlignTTSConfig
-from TTS.tts.configs import BaseDatasetConfig
-from TTS.trainer import init_training, Trainer, TrainingArgs
+from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.tts.configs import AlignTTSConfig, BaseDatasetConfig

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/"))
+dataset_config = BaseDatasetConfig(
+    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+)
 config = AlignTTSConfig(
     batch_size=32,
     eval_batch_size=16,

@@ -23,8 +23,8 @@ config = AlignTTSConfig(
     print_eval=True,
     mixed_precision=False,
     output_path=output_path,
-    datasets=[dataset_config]
+    datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()
@@ -1,12 +1,12 @@
 import os

-from TTS.tts.configs import GlowTTSConfig
-from TTS.tts.configs import BaseDatasetConfig
-from TTS.trainer import init_training, Trainer, TrainingArgs
+from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/"))
+dataset_config = BaseDatasetConfig(
+    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+)
 config = GlowTTSConfig(
     batch_size=32,
     eval_batch_size=16,

@@ -23,8 +23,8 @@ config = GlowTTSConfig(
     print_eval=True,
     mixed_precision=False,
     output_path=output_path,
-    datasets=[dataset_config]
+    datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()
@@ -24,6 +24,6 @@ config = HifiganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()
@@ -1,8 +1,7 @@
 import os

+from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.vocoder.configs import MultibandMelganConfig
-from TTS.trainer import init_training, Trainer, TrainingArgs

 output_path = os.path.dirname(os.path.abspath(__file__))
 config = MultibandMelganConfig(

@@ -25,6 +24,6 @@ config = MultibandMelganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()
@@ -1,6 +1,5 @@
 import os

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.vocoder.configs import UnivnetConfig

@@ -25,6 +24,6 @@ config = UnivnetConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()
@@ -0,0 +1,52 @@
import os

from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, VitsConfig

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    preemphasis=0.0,
    ref_level_db=20,
    log_func="np.log",
    do_trim_silence=True,
    trim_db=45,
    mel_fmin=0,
    mel_fmax=None,
    spec_gain=1.0,
    signal_norm=False,
    do_amp_to_db_linear=False,
)
config = VitsConfig(
    audio=audio_config,
    run_name="vits_ljspeech",
    batch_size=48,
    eval_batch_size=16,
    batch_group_size=0,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    compute_input_seq_cache=True,
    print_step=25,
    print_eval=True,
    mixed_precision=True,
    max_seq_len=5000,
    output_path=output_path,
    datasets=[dataset_config],
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
trainer.fit()
@@ -1,10 +1,8 @@
 import os

-from TTS.trainer import Trainer, init_training
-from TTS.trainer import TrainingArgs
+from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.vocoder.configs import WavegradConfig

 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavegradConfig(
     batch_size=32,

@@ -24,6 +22,6 @@ config = WavegradConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
 trainer.fit()
@@ -1,9 +1,8 @@
 import os

-from TTS.trainer import Trainer, init_training, TrainingArgs
+from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.vocoder.configs import WavernnConfig

 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavernnConfig(
     batch_size=64,

@@ -25,6 +24,6 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
+args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
 trainer.fit()
@@ -24,3 +24,4 @@ mecab-python3==1.0.3
 unidic-lite==1.0.8
 # gruut+supported langs
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
+fsspec>=2021.04.0
@@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase):
             r,
             c.text_cleaner,
             compute_linear_spec=True,
+            return_wav=True,
             ap=self.ap,
             meta_data=items,
             characters=c.characters,

@@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase):
             mel_lengths = data[5]
             stop_target = data[6]
             item_idx = data[7]
+            wavs = data[11]

             neg_values = text_input[text_input < 0]
             check_count = len(neg_values)
             assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
             assert isinstance(speaker_name[0], str)
             assert linear_input.shape[0] == c.batch_size
             assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
             assert mel_input.shape[0] == c.batch_size
             assert mel_input.shape[2] == c.audio["num_mels"]
+            assert (
+                wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
+            ), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
+
+            # make sure that the computed mels and the waveform match and are correctly computed
+            mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
+            ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
+            mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
+            assert abs(mel_diff.sum()) < 1e-5

             # check normalization ranges
             if self.ap.symmetric_norm:
                 assert mel_input.max() <= self.ap.max_norm
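The new assertion encodes the frame/sample bookkeeping used throughout the data loader: with `hop_length` samples per spectrogram frame, the returned waveform must be exactly `n_frames * hop_length` samples long. A tiny numeric sketch with illustrative values:

```python
hop_length = 256
n_frames = 30                       # mel_input.shape[1] for one batch item
expected_samples = n_frames * hop_length
assert expected_samples == 7680     # the value wavs.shape[1] is checked against
```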
@@ -27,6 +27,7 @@ config = AlignTTSConfig(
         "Be a voice, not an echo.",
     ],
 )

 config.audio.do_trim_silence = True
 config.audio.trim_db = 60
 config.save_json(config_path)
@@ -29,6 +29,7 @@ config = Tacotron2Config(
         "Be a voice, not an echo.",
     ],
     d_vector_file="tests/data/ljspeech/speakers.json",
+    d_vector_dim=256,
     max_decoder_steps=50,
 )
@@ -25,8 +25,68 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


 class TacotronTrainTest(unittest.TestCase):
+    """Test vanilla Tacotron2 model."""
+
     def test_train_step(self):  # pylint: disable=no-self-use
         config = config_global.copy()
+        config.use_speaker_embedding = False
+        config.num_speakers = 1
+
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
+        input_lengths = torch.sort(input_lengths, descending=True)[0]
+        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
+        mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[0] = 30
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()) :, 0] = 1.0
+
+        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+
+        criterion = MSELossMasked(seq_len_norm=False).to(device)
+        criterion_st = nn.BCEWithLogitsLoss().to(device)
+        model = Tacotron2(config).to(device)
+        model.train()
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=config.lr)
+        for i in range(5):
+            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
+            assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
+            assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
+            optimizer.zero_grad()
+            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
+            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
+            loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            # ignore pre-highway layer since it works conditionally
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
+            count += 1
+
+
+class MultiSpeakerTacotronTrainTest(unittest.TestCase):
+    """Test multi-speaker Tacotron2 with speaker embedding layer"""
+
+    @staticmethod
+    def test_train_step():
+        config = config_global.copy()
+        config.use_speaker_embedding = True
+        config.num_speakers = 5
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8,)).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]

@@ -45,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):

         criterion = MSELossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
+        config.d_vector_dim = 55
         model = Tacotron2(config).to(device)
         model.train()
         model_ref = copy.deepcopy(model)

@@ -76,65 +137,18 @@ class TacotronTrainTest(unittest.TestCase):
             count += 1


-class MultiSpeakeTacotronTrainTest(unittest.TestCase):
-    @staticmethod
-    def test_train_step():
-        config = config_global.copy()
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
-        input_lengths = torch.sort(input_lengths, descending=True)[0]
-        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
-        mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
-        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
-        mel_lengths[0] = 30
-        stop_targets = torch.zeros(8, 30, 1).float().to(device)
-        speaker_ids = torch.rand(8, 55).to(device)
-
-        for idx in mel_lengths:
-            stop_targets[:, int(idx.item()) :, 0] = 1.0
-
-        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
-
-        criterion = MSELossMasked(seq_len_norm=False).to(device)
-        criterion_st = nn.BCEWithLogitsLoss().to(device)
-        config.d_vector_dim = 55
-        model = Tacotron2(config).to(device)
-        model.train()
-        model_ref = copy.deepcopy(model)
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            assert (param - param_ref).sum() == 0, param
-            count += 1
-        optimizer = optim.Adam(model.parameters(), lr=config.lr)
-        for i in range(5):
-            outputs = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_ids}
-            )
-            assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
-            assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
-            optimizer.zero_grad()
-            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
-            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
-            loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
-            loss.backward()
-            optimizer.step()
-        # check parameter changes
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            # ignore pre-highway layer since it works conditionally
-            # if count not in [145, 59]:
-            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref
-            )
-            count += 1
-
-
 class TacotronGSTTrainTest(unittest.TestCase):
+    """Test multi-speaker Tacotron2 with Global Style Token and Speaker Embedding"""
+
     # pylint: disable=no-self-use
     def test_train_step(self):
         # with random gst mel style
         config = config_global.copy()
+        config.use_speaker_embedding = True
+        config.num_speakers = 10
+        config.use_gst = True
+        config.gst = GSTConfig()
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8,)).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]

@@ -247,9 +261,17 @@ class TacotronGSTTrainTest(unittest.TestCase):


 class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
|
||||||
|
"""Test multi-speaker Tacotron2 with Global Style Tokens and d-vector inputs."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def test_train_step():
|
def test_train_step():
|
||||||
|
|
||||||
config = config_global.copy()
|
config = config_global.copy()
|
||||||
|
config.use_d_vector_file = True
|
||||||
|
|
||||||
|
config.use_gst = True
|
||||||
|
config.gst = GSTConfig()
|
||||||
|
|
||||||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||||
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
|
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
|
||||||
input_lengths = torch.sort(input_lengths, descending=True)[0]
|
input_lengths = torch.sort(input_lengths, descending=True)[0]
|
||||||
|
|
|
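All of the Tacotron2 tests above prepare their stop targets with the same recipe: frames at or beyond each mel length are flagged, and the per-frame flags are then folded by the reduction factor `r` so that there is one target per decoder step. A minimal standalone sketch of that pattern (the helper name and the toy sizes are ours, not from the repository):

```python
import torch


def make_stop_targets(mel_lengths: torch.Tensor, max_len: int, r: int) -> torch.Tensor:
    """Return one stop flag per decoder step for a batch of mel lengths."""
    batch_size = mel_lengths.size(0)
    stop_targets = torch.zeros(batch_size, max_len, 1)
    for i, length in enumerate(mel_lengths):
        # frames at or beyond the sequence length are "stop" frames
        stop_targets[i, int(length.item()):, 0] = 1.0
    # group frames into decoder steps of `r` frames; a step stops if any frame in it stops
    stop_targets = stop_targets.view(batch_size, max_len // r, -1)
    return (stop_targets.sum(2) > 0.0).float()


# toy usage: two sequences padded to 10 frames, reduction factor r=5
print(make_stop_targets(torch.tensor([6, 10]), max_len=10, r=5))
# tensor([[0., 1.],
#         [0., 0.]])
```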
@ -0,0 +1,55 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

config = Tacotron2Config(
    r=5,
    batch_size=8,
    eval_batch_size=8,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=False,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
    print_eval=True,
    max_decoder_steps=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
    f"--coqpit.output_path file://{output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
)
run_cli(command_train)
shutil.rmtree(continue_path)
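This smoke test drives the real `TTS/bin/train_tts.py` entry point through `run_cli`, imported from the `tests` package. Its implementation is not part of this diff; a minimal stand-in consistent with how it is used here (run a shell command and fail the test on a non-zero exit status) might look like the sketch below; treat the body as an assumption, only the call signature is taken from the code above:

```python
import os


def run_cli(command: str) -> None:
    """Run a shell command and fail loudly if it does not exit cleanly."""
    exit_status = os.system(command)
    assert exit_status == 0, f" [!] command failed with exit status {exit_status}: {command}"
```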
@ -32,6 +32,61 @@ class TacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_speaker_embedding = False
        config.num_speakers = 1

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
        linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[-1] = mel_spec.size(1)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = L1LossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
        model.train()
        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
        for _ in range(5):
            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
            optimizer.zero_grad()
            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
            loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore the pre-highway layer since it is used conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class MultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_speaker_embedding = True
        config.num_speakers = 5

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128

@ -50,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):

        criterion = L1LossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        config.d_vector_dim = 55
        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
        model.train()
        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))

@ -80,63 +136,14 @@ class TacotronTrainTest(unittest.TestCase):

            count += 1


class MultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
        linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[-1] = mel_spec.size(1)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_embeddings = torch.rand(8, 55).to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = L1LossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        config.d_vector_dim = 55
        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
        model.train()
        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
        for _ in range(5):
            outputs = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
            )
            optimizer.zero_grad()
            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
            loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore the pre-highway layer since it is used conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class TacotronGSTTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_speaker_embedding = True
        config.num_speakers = 10
        config.use_gst = True
        config.gst = GSTConfig()
        # with random gst mel style
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)

@ -244,6 +251,11 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):

    @staticmethod
    def test_train_step():
        config = config_global.copy()
        config.use_d_vector_file = True

        config.use_gst = True
        config.gst = GSTConfig()

        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
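Both Tacotron model test files repeat the same convergence check: deep-copy the freshly built model, run a handful of optimizer steps, then assert that every parameter moved away from its snapshot. The same idea, factored into a small helper for illustration (the helper and the toy model are ours, not part of the repository):

```python
import copy

import torch
from torch import nn, optim


def assert_all_params_updated(model: nn.Module, train_step, num_steps: int = 5) -> None:
    """Snapshot the parameters, run `train_step` a few times, and check every parameter changed."""
    model_ref = copy.deepcopy(model)
    for _ in range(num_steps):
        train_step()
    for count, (param, param_ref) in enumerate(zip(model.parameters(), model_ref.parameters())):
        assert (param != param_ref).any(), f"param {count} with shape {param.shape} not updated!"


# toy usage with a stand-in model instead of Tacotron
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)


def step():
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()


assert_all_params_updated(model, step)
```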
@ -0,0 +1,54 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
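Like the Tacotron2 script earlier, this VITS smoke test resumes training from whichever run directory was created last: it globs the run folders under the output path and takes the most recently modified one. A small sketch of that selection step with throwaway directories (helper name and folder names are invented for illustration):

```python
import glob
import os
import tempfile


def find_latest_run(output_path: str) -> str:
    """Return the most recently modified run directory under `output_path`."""
    run_dirs = glob.glob(os.path.join(output_path, "*/"))
    assert run_dirs, f"no run directories found under {output_path}"
    return max(run_dirs, key=os.path.getmtime)


# throwaway demonstration
with tempfile.TemporaryDirectory() as output_path:
    old_run = os.path.join(output_path, "run-April-01")
    new_run = os.path.join(output_path, "run-April-02")
    os.mkdir(old_run)
    os.mkdir(new_run)
    os.utime(old_run, (0, 0))  # force an older modification time on the first run
    print(find_latest_run(output_path))  # path of the newer run directory
```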