Merge pull request #717 from coqui-ai/dev

v0.2.0
Commit 01a2b0b5c0 by Eren Gölge, 2021-08-11 10:08:04 +02:00 (committed via GitHub)
94 changed files with 3610 additions and 657 deletions

View File

@ -1,10 +1,15 @@
--- # Pull request guidelines
name: 'Contribution Guideline '
about: Refer to Contirbution Guideline
title: ''
labels: ''
assignees: ''
--- Welcome to the 🐸TTS project! We are excited to see your interest, and appreciate your support!
👐 Please check our [CONTRIBUTION GUIDELINE](https://github.com/coqui-ai/TTS#contribution-guidelines). This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.
In order to make a good pull request, please see our [CONTRIBUTING.md](CONTRIBUTING.md) file.
Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS).
This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS):
- Protects you, Coqui, and the users of the code.
- Does not change your rights to use your contributions for any purpose.
- Does not change the license of the 🐸TTS project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute.

.gitignore
View File

@ -136,6 +136,7 @@ TTS/tts/layers/glow_tts/monotonic_align/core.c
temp_build/* temp_build/*
recipes/WIP/* recipes/WIP/*
recipes/ljspeech/LJSpeech-1.1/* recipes/ljspeech/LJSpeech-1.1/*
recipes/ljspeech/tacotron2-DDC/LJSpeech-1.1/*
events.out* events.out*
old_configs/* old_configs/*
model_importers/* model_importers/*
@ -152,4 +153,6 @@ output.wav
tts_output.wav tts_output.wav
deps.json deps.json
speakers.json speakers.json
internal/* internal/*
*_pitch.npy
*_phoneme.npy

View File

@ -4,7 +4,7 @@
help: help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
target_dirs := tests TTS notebooks target_dirs := tests TTS notebooks recipes
test_all: ## run tests and don't stop on an error. test_all: ## run tests and don't stop on an error.
nosetests --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id nosetests --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id

View File

@ -73,10 +73,13 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802) - Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950) - Align-TTS: [paper](https://arxiv.org/abs/2003.01950)
### End-to-End Models
- VITS: [paper](https://arxiv.org/pdf/2106.06103)
### Attention Methods ### Attention Methods
- Guided Attention: [paper](https://arxiv.org/abs/1710.08969) - Guided Attention: [paper](https://arxiv.org/abs/1710.08969)
- Forward Backward Decoding: [paper](https://arxiv.org/abs/1907.09006) - Forward Backward Decoding: [paper](https://arxiv.org/abs/1907.09006)
- Graves Attention: [paper](https://arxiv.org/abs/1907.09006) - Graves Attention: [paper](https://arxiv.org/abs/1910.10288)
- Double Decoder Consistency: [blog](https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency/) - Double Decoder Consistency: [blog](https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency/)
- Dynamic Convolutional Attention: [paper](https://arxiv.org/pdf/1910.10288.pdf) - Dynamic Convolutional Attention: [paper](https://arxiv.org/pdf/1910.10288.pdf)
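
The newly listed VITS entry is end-to-end, so no separate vocoder is needed at synthesis time. A hedged sketch of using it through the Python API: the positional arguments mirror the `Synthesizer(...)` call in the server.py hunk later in this commit, and the paths are placeholders rather than real release files.

```python
from TTS.utils.synthesizer import Synthesizer

# Placeholder paths; in practice they come from ModelManager.download_model()
# for "tts_models/en/ljspeech/vits", one of the entries added in this release.
model_path = "path/to/tts_models--en--ljspeech--vits/model_file.pth.tar"
config_path = "path/to/tts_models--en--ljspeech--vits/config.json"

# Speakers file, vocoder checkpoint and vocoder config are all None for VITS.
synthesizer = Synthesizer(model_path, config_path, None, None, None)
wav = synthesizer.tts("Hello from Coqui TTS.")
synthesizer.save_wav(wav, "vits_sample.wav")
```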

View File

@ -1,7 +1,7 @@
{ {
"tts_models":{ "tts_models": {
"en":{ "en": {
"ek1":{ "ek1": {
"tacotron2": { "tacotron2": {
"description": "EK1 en-rp tacotron2 by NMStoker", "description": "EK1 en-rp tacotron2 by NMStoker",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ek1--tacotron2.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ek1--tacotron2.zip",
@ -9,7 +9,7 @@
"commit": "c802255" "commit": "c802255"
} }
}, },
"ljspeech":{ "ljspeech": {
"tacotron2-DDC": { "tacotron2-DDC": {
"description": "Tacotron2 with Double Decoder Consistency.", "description": "Tacotron2 with Double Decoder Consistency.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/tts_models--en--ljspeech--tacotron2-DDC.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/tts_models--en--ljspeech--tacotron2-DDC.zip",
@ -17,9 +17,18 @@
"commit": "bae2ad0f", "commit": "bae2ad0f",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "", "license": "",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
}, },
"glow-tts":{ "tacotron2-DDC_ph": {
"description": "Tacotron2 with Double Decoder Consistency with phonemes.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--tacotronDDC_ph.zip",
"default_vocoder": "vocoder_models/en/ljspeech/univnet",
"commit": "3900448",
"author": "Eren Gölge @erogol",
"license": "",
"contact": "egolge@coqui.com"
},
"glow-tts": {
"description": "", "description": "",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--en--ljspeech--glow-tts.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--en--ljspeech--glow-tts.zip",
"stats_file": null, "stats_file": null,
@ -27,7 +36,7 @@
"commit": "", "commit": "",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "MPL", "license": "MPL",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
}, },
"tacotron2-DCA": { "tacotron2-DCA": {
"description": "", "description": "",
@ -36,19 +45,28 @@
"commit": "", "commit": "",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "MPL", "license": "MPL",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
}, },
"speedy-speech-wn":{ "speedy-speech-wn": {
"description": "Speedy Speech model with wavenet decoder.", "description": "Speedy Speech model with wavenet decoder.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ljspeech--speedy-speech-wn.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ljspeech--speedy-speech-wn.zip",
"default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan", "default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan",
"commit": "77b6145", "commit": "77b6145",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "MPL", "license": "MPL",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
},
"vits": {
"description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--vits.zip",
"default_vocoder": null,
"commit": "3900448",
"author": "Eren Gölge @erogol",
"license": "",
"contact": "egolge@coqui.com"
} }
}, },
"vctk":{ "vctk": {
"sc-glow-tts": { "sc-glow-tts": {
"description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.", "description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip",
@ -56,12 +74,19 @@
"commit": "b531fa69", "commit": "b531fa69",
"author": "Edresson Casanova", "author": "Edresson Casanova",
"license": "", "license": "",
"contact":"" "contact": ""
},
"vits": {
"description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--vctk--vits.zip",
"default_vocoder": null,
"commit": "3900448",
"author": "Eren @erogol",
"license": "",
"contact": "egolge@coqui.ai"
} }
}, },
"sam":{ "sam": {
"tacotron-DDC": { "tacotron-DDC": {
"description": "Tacotron2 with Double Decoder Consistency trained with Aceenture's Sam dataset.", "description": "Tacotron2 with Double Decoder Consistency trained with Aceenture's Sam dataset.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/tts_models--en--sam--tacotron_DDC.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/tts_models--en--sam--tacotron_DDC.zip",
@ -69,46 +94,47 @@
"commit": "bae2ad0f", "commit": "bae2ad0f",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "", "license": "",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
} }
} }
}, },
"es":{ "es": {
"mai":{ "mai": {
"tacotron2-DDC":{ "tacotron2-DDC": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--es--mai--tacotron2-DDC.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--es--mai--tacotron2-DDC.zip",
"default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan", "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
"commit": "", "commit": "",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "MPL", "license": "MPL",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
} }
} }
}, },
"fr":{ "fr": {
"mai":{ "mai": {
"tacotron2-DDC":{ "tacotron2-DDC": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--fr--mai--tacotron2-DDC.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/tts_models--fr--mai--tacotron2-DDC.zip",
"default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan", "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
"commit": "", "commit": "",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "MPL", "license": "MPL",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
} }
} }
}, },
"zh-CN":{ "zh-CN": {
"baker":{ "baker": {
"tacotron2-DDC-GST":{ "tacotron2-DDC-GST": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip",
"commit": "unknown", "commit": "unknown",
"author": "@kirianguiller" "author": "@kirianguiller",
"default_vocoder": null
} }
} }
}, },
"nl":{ "nl": {
"mai":{ "mai": {
"tacotron2-DDC":{ "tacotron2-DDC": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--nl--mai--tacotron2-DDC.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--nl--mai--tacotron2-DDC.zip",
"author": "@r-dh", "author": "@r-dh",
"default_vocoder": "vocoder_models/nl/mai/parallel-wavegan", "default_vocoder": "vocoder_models/nl/mai/parallel-wavegan",
@ -117,20 +143,9 @@
} }
} }
}, },
"ru":{ "de": {
"ruslan":{ "thorsten": {
"tacotron2-DDC":{ "tacotron2-DCA": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--ru--ruslan--tacotron2-DDC.zip",
"author": "@erogol",
"default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
"license":"",
"contact": "egolge@coqui.com"
}
}
},
"de":{
"thorsten":{
"tacotron2-DCA":{
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/tts_models--de--thorsten--tacotron2-DCA.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/tts_models--de--thorsten--tacotron2-DCA.zip",
"default_vocoder": "vocoder_models/de/thorsten/fullband-melgan", "default_vocoder": "vocoder_models/de/thorsten/fullband-melgan",
"author": "@thorstenMueller", "author": "@thorstenMueller",
@ -138,9 +153,9 @@
} }
} }
}, },
"ja":{ "ja": {
"kokoro":{ "kokoro": {
"tacotron2-DDC":{ "tacotron2-DDC": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.15/tts_models--jp--kokoro--tacotron2-DDC.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.15/tts_models--jp--kokoro--tacotron2-DDC.zip",
"default_vocoder": "vocoder_models/universal/libri-tts/wavegrad", "default_vocoder": "vocoder_models/universal/libri-tts/wavegrad",
"description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.", "description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.",
@ -150,54 +165,62 @@
} }
} }
}, },
"vocoder_models":{ "vocoder_models": {
"universal":{ "universal": {
"libri-tts":{ "libri-tts": {
"wavegrad":{ "wavegrad": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--wavegrad.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--wavegrad.zip",
"commit": "ea976b0", "commit": "ea976b0",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "MPL", "license": "MPL",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
}, },
"fullband-melgan":{ "fullband-melgan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--fullband-melgan.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--universal--libri-tts--fullband-melgan.zip",
"commit": "4132240", "commit": "4132240",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "MPL", "license": "MPL",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
} }
} }
}, },
"en": { "en": {
"ek1":{ "ek1": {
"wavegrad": { "wavegrad": {
"description": "EK1 en-rp wavegrad by NMStoker", "description": "EK1 en-rp wavegrad by NMStoker",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--en--ek1--wavegrad.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--en--ek1--wavegrad.zip",
"commit": "c802255" "commit": "c802255"
} }
}, },
"ljspeech":{ "ljspeech": {
"multiband-melgan":{ "multiband-melgan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--en--ljspeech--mulitband-melgan.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.9/vocoder_models--en--ljspeech--mulitband-melgan.zip",
"commit": "ea976b0", "commit": "ea976b0",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "MPL", "license": "MPL",
"contact":"egolge@coqui.com" "contact": "egolge@coqui.com"
}, },
"hifigan_v2":{ "hifigan_v2": {
"description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.", "description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--ljspeech-hifigan_v2.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--ljspeech-hifigan_v2.zip",
"commit": "bae2ad0f", "commit": "bae2ad0f",
"author": "@erogol", "author": "@erogol",
"license": "", "license": "",
"contact": "egolge@coqui.ai" "contact": "egolge@coqui.ai"
},
"univnet": {
"description": "UnivNet model trained on LJSpeech to complement the TacotronDDC_ph model.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/vocoder_models--en--ljspeech--univnet.zip",
"commit": "3900448",
"author": "Eren @erogol",
"license": "",
"contact": "egolge@coqui.ai"
} }
}, },
"vctk":{ "vctk": {
"hifigan_v2":{ "hifigan_v2": {
"description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts", "description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts",
"github_rls_url":"https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--vctk--hifigan_v2.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/vocoder_model--en--vctk--hifigan_v2.zip",
"commit": "2f07160", "commit": "2f07160",
"author": "Edresson Casanova", "author": "Edresson Casanova",
"license": "", "license": "",
@ -205,9 +228,9 @@
} }
}, },
"sam": { "sam": {
"hifigan_v2":{ "hifigan_v2": {
"description": "Finetuned and intended to be used with tts_models/en/sam/tacotron_DDC", "description": "Finetuned and intended to be used with tts_models/en/sam/tacotron_DDC",
"github_rls_url":"https://github.com/coqui-ai/TTS/releases/download/v0.0.13/vocoder_models--en--sam--hifigan_v2.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.13/vocoder_models--en--sam--hifigan_v2.zip",
"commit": "2f07160", "commit": "2f07160",
"author": "Eren Gölge @erogol", "author": "Eren Gölge @erogol",
"license": "", "license": "",
@ -215,28 +238,38 @@
} }
} }
}, },
"nl":{ "nl": {
"mai":{ "mai": {
"parallel-wavegan":{ "parallel-wavegan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--nl--mai--parallel-wavegan.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/vocoder_models--nl--mai--parallel-wavegan.zip",
"author": "@r-dh", "author": "@r-dh",
"commit": "unknown" "commit": "unknown"
} }
} }
}, },
"de":{ "de": {
"thorsten":{ "thorsten": {
"wavegrad":{ "wavegrad": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/vocoder_models--de--thorsten--wavegrad.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/vocoder_models--de--thorsten--wavegrad.zip",
"author": "@thorstenMueller", "author": "@thorstenMueller",
"commit": "unknown" "commit": "unknown"
}, },
"fullband-melgan":{ "fullband-melgan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.3/vocoder_models--de--thorsten--fullband-melgan.zip", "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.3/vocoder_models--de--thorsten--fullband-melgan.zip",
"author": "@thorstenMueller", "author": "@thorstenMueller",
"commit": "unknown" "commit": "unknown"
} }
} }
},
"ja": {
"kokoro": {
"hifigan_v1": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/vocoder_models--ja--kokoro--hifigan_v1.zip",
"description": "HifiGAN model trained for kokoro dataset by @kaiidams",
"author": "@kaiidams",
"commit": "3900448"
}
}
} }
} }
} }
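
The catalogue above is what `TTS.utils.manage.ModelManager` resolves model names against; `synthesize.py` and `server.py` below rely on it. A rough usage sketch: the constructor argument (the path to this `.models.json`) is an assumption, while `download_model()` returning a three-tuple matches the synthesize.py hunk further down.

```python
from pathlib import Path

import TTS
from TTS.utils.manage import ModelManager

# Assumed: the catalogue shown above ships inside the installed package.
models_json = Path(TTS.__file__).parent / ".models.json"
manager = ModelManager(models_json)

# One of the new v0.2.0 entries; the zip is fetched from its github_rls_url.
model_path, config_path, model_item = manager.download_model("tts_models/en/ljspeech/vits")
print(model_item["default_vocoder"])  # None, since VITS is end-to-end
```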

View File

@ -6,7 +6,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
from TTS.utils.io import load_config from TTS.utils.io import load_config, load_fsspec
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, compare_torch_tf,
convert_tf_name, convert_tf_name,
@ -33,7 +33,7 @@ num_speakers = 0
# init torch model # init torch model
model = setup_generator(c) model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"] state_dict = checkpoint["model"]
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
model.remove_weight_norm() model.remove_weight_norm()
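
`load_fsspec` (and its counterpart `save_fsspec`, used in later hunks) lives in `TTS.utils.io`, but its body is not part of the hunks shown here. A minimal sketch of the assumed behaviour: wrap `torch.load`/`torch.save` in `fsspec.open` so checkpoints can sit on local disk or on remote object storage.

```python
import fsspec
import torch


def load_fsspec(path, map_location=None, **kwargs):
    """Assumed: torch.load through fsspec, so s3:// or gs:// URLs work like local paths."""
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)


def save_fsspec(state, path, **kwargs):
    """Assumed: torch.save through fsspec, mirroring load_fsspec."""
    with fsspec.open(path, "wb") as f:
        torch.save(state, f, **kwargs)
```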

View File

@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.text.symbols import phonemes, symbols from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config from TTS.utils.io import load_config, load_fsspec
sys.path.append("/home/erogol/Projects") sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = "" os.environ["CUDA_VISIBLE_DEVICES"] = ""
@ -32,7 +32,7 @@ num_speakers = 0
# init torch model # init torch model
model = setup_model(c) model = setup_model(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"] state_dict = checkpoint["model"]
model.load_state_dict(state_dict) model.load_state_dict(state_dict)

View File

@ -16,6 +16,7 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters from TTS.utils.generic_utils import count_parameters
from TTS.utils.io import load_fsspec
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
@ -239,7 +240,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_model(c) model = setup_model(c)
# restore model # restore model
checkpoint = torch.load(args.checkpoint_path, map_location="cpu") checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
if use_cuda: if use_cuda:

View File

@ -208,7 +208,7 @@ def main():
if args.vocoder_name is not None and not args.vocoder_path: if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE3: set custome model paths # CASE3: set custom model paths
if args.model_path is not None: if args.model_path is not None:
model_path = args.model_path model_path = args.model_path
config_path = args.config_path config_path = args.config_path

View File

@ -17,6 +17,7 @@ from TTS.trainer import init_training
from TTS.tts.datasets import load_meta_data from TTS.tts.datasets import load_meta_data
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, check_update from TTS.utils.training import NoamLR, check_update
@ -115,12 +116,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
"step_time": step_time, "step_time": step_time,
"avg_loader_time": avg_loader_time, "avg_loader_time": avg_loader_time,
} }
tb_logger.tb_train_epoch_stats(global_step, train_stats) dashboard_logger.train_epoch_stats(global_step, train_stats)
figures = { figures = {
# FIXME: not constant # FIXME: not constant
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10), "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
} }
tb_logger.tb_train_figures(global_step, figures) dashboard_logger.train_figures(global_step, figures)
if global_step % c.print_step == 0: if global_step % c.print_step == 0:
print( print(
@ -169,7 +170,7 @@ def main(args): # pylint: disable=redefined-outer-name
raise Exception("The %s not is a loss supported" % c.loss) raise Exception("The %s not is a loss supported" % c.loss)
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path) checkpoint = load_fsspec(args.restore_path)
try: try:
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
@ -207,7 +208,7 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == "__main__": if __name__ == "__main__":
args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv)
try: try:
main(args) main(args)
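
The `tb_*` calls are renamed to a backend-agnostic logger interface here and throughout `trainer.py` below. The following sketch collects the methods those call sites assume (names taken from the hunks in this commit); it is not the actual `TTS.utils.logging` implementation, which also provides `TensorboardLogger`, `WandbLogger` and `init_logger`.

```python
from typing import Dict, List, Optional


class DashboardLogger:
    """Assumed protocol shared by TensorboardLogger and WandbLogger."""

    def train_step_stats(self, step: int, stats: Dict) -> None: ...

    def train_epoch_stats(self, step: int, stats: Dict) -> None: ...

    def train_figures(self, step: int, figures: Dict) -> None: ...

    def train_audios(self, step: int, audios: Dict, sample_rate: int) -> None: ...

    def model_weights(self, model, step: int) -> None: ...

    def log_artifact(self, file_or_dir: str, name: str, artifact_type: str,
                     aliases: Optional[List[str]] = None) -> None: ...

    def flush(self) -> None: ...
```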

View File

@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training
def main(): def main():
"""Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```""" """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv) args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False)
trainer.fit() trainer.fit()

View File

@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder
def main(): def main():
try: try:
args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv) args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit() trainer.fit()
except KeyboardInterrupt: except KeyboardInterrupt:
remove_experiment_folder(output_path) remove_experiment_folder(output_path)

View File

@ -3,6 +3,7 @@ import os
import re import re
from typing import Dict from typing import Dict
import fsspec
import yaml import yaml
from coqpit import Coqpit from coqpit import Coqpit
@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module
def read_json_with_comments(json_path): def read_json_with_comments(json_path):
"""for backward compat.""" """for backward compat."""
# fallback to json # fallback to json
with open(json_path, "r", encoding="utf-8") as f: with fsspec.open(json_path, "r", encoding="utf-8") as f:
input_str = f.read() input_str = f.read()
# handle comments # handle comments
input_str = re.sub(r"\\\n", "", input_str) input_str = re.sub(r"\\\n", "", input_str)
@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
config_dict = {} config_dict = {}
ext = os.path.splitext(config_path)[1] ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"): if ext in (".yml", ".yaml"):
with open(config_path, "r", encoding="utf-8") as f: with fsspec.open(config_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) data = yaml.safe_load(f)
elif ext == ".json": elif ext == ".json":
try: try:
with open(config_path, "r", encoding="utf-8") as f: with fsspec.open(config_path, "r", encoding="utf-8") as f:
input_str = f.read() data = json.load(f)
data = json.loads(input_str)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
# backwards compat. # backwards compat.
data = read_json_with_comments(config_path) data = read_json_with_comments(config_path)
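
With `fsspec` handling the file access, `load_config` is no longer limited to local paths. A hedged usage sketch, assuming `load_config` is exported from `TTS.config` (where this hunk appears to live); the remote bucket is made up and would additionally need the matching fsspec backend installed (e.g. `gcsfs`).

```python
from TTS.config import load_config

config = load_config("config.json")                          # local file, as before
# config = load_config("gs://my-bucket/run-1/config.json")   # now also possible via fsspec
print(config.model)
```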

View File

@ -36,6 +36,10 @@ class BaseAudioConfig(Coqpit):
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
do_trim_silence (bool): do_trim_silence (bool):
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
trim_db (int): trim_db (int):
Silence threshold used for silence trimming. Defaults to 45. Silence threshold used for silence trimming. Defaults to 45.
power (float): power (float):
@ -79,7 +83,7 @@ class BaseAudioConfig(Coqpit):
preemphasis: float = 0.0 preemphasis: float = 0.0
ref_level_db: int = 20 ref_level_db: int = 20
do_sound_norm: bool = False do_sound_norm: bool = False
log_func = "np.log10" log_func: str = "np.log10"
# silence trimming # silence trimming
do_trim_silence: bool = True do_trim_silence: bool = True
trim_db: int = 45 trim_db: int = 45
@ -91,6 +95,8 @@ class BaseAudioConfig(Coqpit):
mel_fmin: float = 0.0 mel_fmin: float = 0.0
mel_fmax: float = None mel_fmax: float = None
spec_gain: int = 20 spec_gain: int = 20
do_amp_to_db_linear: bool = True
do_amp_to_db_mel: bool = True
# normalization params # normalization params
signal_norm: bool = True signal_norm: bool = True
min_level_db: int = -100 min_level_db: int = -100
@ -182,51 +188,87 @@ class BaseTrainingConfig(Coqpit):
Args: Args:
model (str): model (str):
Name of the model that is used in the training. Name of the model that is used in the training.
run_name (str): run_name (str):
Name of the experiment. This prefixes the output folder name. Name of the experiment. This prefixes the output folder name.
run_description (str): run_description (str):
Short description of the experiment. Short description of the experiment.
epochs (int): epochs (int):
Number training epochs. Defaults to 10000. Number training epochs. Defaults to 10000.
batch_size (int): batch_size (int):
Training batch size. Training batch size.
eval_batch_size (int): eval_batch_size (int):
Validation batch size. Validation batch size.
mixed_precision (bool): mixed_precision (bool):
Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
it may also cause numerical unstability in some cases. it may also cause numerical unstability in some cases.
scheduler_after_epoch (bool):
If true, run the scheduler step after each epoch else run it after each model step.
run_eval (bool): run_eval (bool):
Enable / Disable evaluation (validation) run. Defaults to True. Enable / Disable evaluation (validation) run. Defaults to True.
test_delay_epochs (int): test_delay_epochs (int):
Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
results, hence waiting for a couple of epochs might save some time. results, hence waiting for a couple of epochs might save some time.
print_eval (bool): print_eval (bool):
Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at
the end of the evaluation. Default to ```False```. the end of the evaluation. Default to ```False```.
print_step (int): print_step (int):
Number of steps required to print the next training log. Number of steps required to print the next training log.
tb_plot_step (int):
log_dashboard (str): "tensorboard" or "wandb"
Set the experiment tracking tool
plot_step (int):
Number of steps required to log training on Tensorboard. Number of steps required to log training on Tensorboard.
tb_model_param_stats (bool):
model_param_stats (bool):
Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
Defaults to ```False```. Defaults to ```False```.
project_name (str):
Name of the project. Defaults to config.model
wandb_entity (str):
Name of W&B entity/team. Enables collaboration across a team or org.
log_model_step (int):
Number of steps required to log a checkpoint as W&B artifact
save_step (int): save_step (int):
Number of steps required to save the next checkpoint. Number of steps required to save the next checkpoint.
checkpoint (bool): checkpoint (bool):
Enable / Disable checkpointing. Enable / Disable checkpointing.
keep_all_best (bool): keep_all_best (bool):
Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
to ```False```. to ```False```.
keep_after (int): keep_after (int):
Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
to 10000. to 10000.
num_loader_workers (int): num_loader_workers (int):
Number of workers for training time dataloader. Number of workers for training time dataloader.
num_eval_loader_workers (int): num_eval_loader_workers (int):
Number of workers for evaluation time dataloader. Number of workers for evaluation time dataloader.
output_path (str): output_path (str):
Path for training output folder. The nonexist part of the given path is created automatically. Path for training output folder, either a local file path or other
All training outputs are saved there. URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
S3 (s3://) paths. The nonexist part of the given path is created
automatically. All training artefacts are saved there.
""" """
model: str = None model: str = None
@ -237,14 +279,19 @@ class BaseTrainingConfig(Coqpit):
batch_size: int = None batch_size: int = None
eval_batch_size: int = None eval_batch_size: int = None
mixed_precision: bool = False mixed_precision: bool = False
scheduler_after_epoch: bool = False
# eval params # eval params
run_eval: bool = True run_eval: bool = True
test_delay_epochs: int = 0 test_delay_epochs: int = 0
print_eval: bool = False print_eval: bool = False
# logging # logging
dashboard_logger: str = "tensorboard"
print_step: int = 25 print_step: int = 25
tb_plot_step: int = 100 plot_step: int = 100
tb_model_param_stats: bool = False model_param_stats: bool = False
project_name: str = None
log_model_step: int = None
wandb_entity: str = None
# checkpointing # checkpointing
save_step: int = 10000 save_step: int = 10000
checkpoint: bool = True checkpoint: bool = True
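
Taken together, the renamed and newly added training fields can be set like any other Coqpit attribute. An illustrative sketch only: the import path is assumed to be `TTS.config.shared_configs`, and the values are placeholders, but every field name appears verbatim in the hunk above.

```python
from TTS.config.shared_configs import BaseTrainingConfig

config = BaseTrainingConfig(
    model="tacotron2",
    batch_size=32,
    eval_batch_size=16,
    scheduler_after_epoch=False,          # new: step the LR scheduler per epoch instead of per step
    dashboard_logger="wandb",             # new: "tensorboard" (default) or "wandb"
    project_name="ljspeech-experiments",  # new: W&B project, defaults to config.model
    wandb_entity="my-team",               # new: W&B team/entity
    log_model_step=10000,                 # new: log a checkpoint artifact every N steps
    plot_step=100,                        # renamed from tb_plot_step
    model_param_stats=False,              # renamed from tb_model_param_stats
    output_path="gs://my-bucket/tts-runs/",  # remote output paths now work through fsspec
)
```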

View File

@ -103,8 +103,8 @@ synthesizer = Synthesizer(
model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
) )
use_multi_speaker = synthesizer.tts_model.speaker_manager is not None and synthesizer.tts_model.num_speakers > 1 use_multi_speaker = hasattr(synthesizer.tts_model, "speaker_manager") and synthesizer.tts_model.num_speakers > 1
speaker_manager = synthesizer.tts_model.speaker_manager if hasattr(synthesizer.tts_model, "speaker_manager") else None speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
# TODO: set this from SpeakerManager # TODO: set this from SpeakerManager
use_gst = synthesizer.tts_config.get("use_gst", False) use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__) app = Flask(__name__)
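
The `getattr`/`hasattr` pattern above avoids assuming every `tts_model` carries a `speaker_manager`. A small self-contained sketch of the same idea; the paths are placeholders and `speaker_ids` is an assumed attribute of `SpeakerManager`.

```python
from TTS.utils.synthesizer import Synthesizer

# Placeholder paths for a multi-speaker model such as tts_models/en/vctk/vits.
synthesizer = Synthesizer("path/to/model.pth.tar", "path/to/config.json", None, None, None)

# Not every tts_model is multi-speaker, so probe defensively before using it.
speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
if speaker_manager is not None and synthesizer.tts_model.num_speakers > 1:
    print(sorted(speaker_manager.speaker_ids))  # assumed attribute listing speakers
```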

View File

@ -2,6 +2,8 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from TTS.utils.io import load_fsspec
class LSTMWithProjection(nn.Module): class LSTMWithProjection(nn.Module):
def __init__(self, input_size, hidden_size, proj_size): def __init__(self, input_size, hidden_size, proj_size):
@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module):
# pylint: disable=unused-argument, redefined-builtin # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if use_cuda: if use_cuda:
self.cuda() self.cuda()

View File

@ -2,6 +2,8 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from TTS.utils.io import load_fsspec
class SELayer(nn.Module): class SELayer(nn.Module):
def __init__(self, channel, reduction=8): def __init__(self, channel, reduction=8):
@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module):
return embeddings return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if use_cuda: if use_cuda:
self.cuda() self.cuda()

View File

@ -6,11 +6,11 @@ import re
from multiprocessing import Manager from multiprocessing import Manager
import numpy as np import numpy as np
import torch
from scipy import signal from scipy import signal
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
class Storage(object): class Storage(object):
@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
"loss": model_loss, "loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }
torch.save(state, checkpoint_path) save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step): def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
bestmodel_path = "best_model.pth.tar" bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path) bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path) save_fsspec(state, bestmodel_path)
return best_loss return best_loss

View File

@ -1,7 +1,7 @@
import datetime import datetime
import os import os
import torch from TTS.utils.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step): def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
"loss": model_loss, "loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }
torch.save(state, checkpoint_path) save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s
bestmodel_path = "best_model.pth.tar" bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path) bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path) save_fsspec(state, bestmodel_path)
return best_loss return best_loss

View File

@ -1,8 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import glob
import importlib import importlib
import logging import logging
import multiprocessing
import os import os
import platform import platform
import re import re
@ -12,7 +12,9 @@ import traceback
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse
import fsspec
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
@ -29,18 +31,20 @@ from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import ( from TTS.utils.generic_utils import (
KeepAverage, KeepAverage,
count_parameters, count_parameters,
create_experiment_folder, get_experiment_folder_path,
get_git_branch, get_git_branch,
remove_experiment_folder, remove_experiment_folder,
set_init_dict, set_init_dict,
to_cuda, to_cuda,
) )
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.models import setup_model as setup_vocoder_model
multiprocessing.set_start_method("fork")
if platform.system() != "Windows": if platform.system() != "Windows":
# https://github.com/pytorch/pytorch/issues/973 # https://github.com/pytorch/pytorch/issues/973
import resource import resource
@ -48,6 +52,7 @@ if platform.system() != "Windows":
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
if is_apex_available(): if is_apex_available():
from apex import amp from apex import amp
@ -87,7 +92,7 @@ class Trainer:
config: Coqpit, config: Coqpit,
output_path: str, output_path: str,
c_logger: ConsoleLogger = None, c_logger: ConsoleLogger = None,
tb_logger: TensorboardLogger = None, dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
model: nn.Module = None, model: nn.Module = None,
cudnn_benchmark: bool = False, cudnn_benchmark: bool = False,
) -> None: ) -> None:
@ -112,7 +117,7 @@ class Trainer:
c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default
console logger is used. Defaults to None. console logger is used. Defaults to None.
tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used. dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used.
Defaults to None. Defaults to None.
model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
@ -134,8 +139,8 @@ class Trainer:
Running trainer on a config. Running trainer on a config.
>>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,) >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
>>> args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
>>> trainer = Trainer(args, config, output_path, c_logger, tb_logger) >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
>>> trainer.fit() >>> trainer.fit()
TODO: TODO:
@ -148,22 +153,24 @@ class Trainer:
# set and initialize Pytorch runtime # set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark) self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
if config is None: if config is None:
# parse config from console arguments # parse config from console arguments
config, output_path, _, c_logger, tb_logger = process_args(args) config, output_path, _, c_logger, dashboard_logger = process_args(args)
self.output_path = output_path self.output_path = output_path
self.args = args self.args = args
self.config = config self.config = config
self.config.output_log_path = output_path
# init loggers # init loggers
self.c_logger = ConsoleLogger() if c_logger is None else c_logger self.c_logger = ConsoleLogger() if c_logger is None else c_logger
if tb_logger is None: self.dashboard_logger = dashboard_logger
self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0) if self.dashboard_logger is None:
else: self.dashboard_logger = init_logger(config)
self.tb_logger = tb_logger
if not self.config.log_model_step:
self.config.log_model_step = self.config.save_step
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file) self._setup_logger_config(log_file)
@ -173,7 +180,6 @@ class Trainer:
self.best_loss = float("inf") self.best_loss = float("inf")
self.train_loader = None self.train_loader = None
self.eval_loader = None self.eval_loader = None
self.output_audio_path = os.path.join(output_path, "test_audios")
self.keep_avg_train = None self.keep_avg_train = None
self.keep_avg_eval = None self.keep_avg_eval = None
@ -184,7 +190,7 @@ class Trainer:
# init audio processor # init audio processor
self.ap = AudioProcessor(**self.config.audio.to_dict()) self.ap = AudioProcessor(**self.config.audio.to_dict())
# load dataset samples # load data samples
# TODO: refactor this # TODO: refactor this
if "datasets" in self.config: if "datasets" in self.config:
# load data for `tts` models # load data for `tts` models
@ -205,6 +211,10 @@ class Trainer:
else: else:
self.model = self.get_model(self.config) self.model = self.get_model(self.config)
# init multispeaker settings of the model
if hasattr(self.model, "init_multispeaker"):
self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
# setup criterion # setup criterion
self.criterion = self.get_criterion(self.model) self.criterion = self.get_criterion(self.model)
@ -274,9 +284,9 @@ class Trainer:
""" """
# TODO: better model setup # TODO: better model setup
try: try:
model = setup_tts_model(config)
except ModuleNotFoundError:
model = setup_vocoder_model(config) model = setup_vocoder_model(config)
except ModuleNotFoundError:
model = setup_tts_model(config)
return model return model
def restore_model( def restore_model(
@ -309,7 +319,7 @@ class Trainer:
return obj return obj
print(" > Restoring from %s ..." % os.path.basename(restore_path)) print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = torch.load(restore_path) checkpoint = load_fsspec(restore_path)
try: try:
print(" > Restoring Model...") print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
@ -417,7 +427,7 @@ class Trainer:
scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access
config: Coqpit, config: Coqpit,
optimizer_idx: int = None, optimizer_idx: int = None,
) -> Tuple[Dict, Dict, int, torch.Tensor]: ) -> Tuple[Dict, Dict, int]:
"""Perform a forward - backward pass and run the optimizer. """Perform a forward - backward pass and run the optimizer.
Args: Args:
@ -426,7 +436,7 @@ class Trainer:
optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use. optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
scaler (AMPScaler): AMP scaler. scaler (AMPScaler): AMP scaler.
criterion (nn.Module): Model's criterion. criterion (nn.Module): Model's criterion.
scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer. scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
config (Coqpit): Model config. config (Coqpit): Model config.
optimizer_idx (int, optional): Target optimizer being used. Defaults to None. optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
@ -436,6 +446,7 @@ class Trainer:
Returns: Returns:
Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm. Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
""" """
step_start_time = time.time() step_start_time = time.time()
# zero-out optimizer # zero-out optimizer
optimizer.zero_grad() optimizer.zero_grad()
@ -448,11 +459,11 @@ class Trainer:
# skip the rest # skip the rest
if outputs is None: if outputs is None:
step_time = time.time() - step_start_time step_time = time.time() - step_start_time
return None, {}, step_time, 0 return None, {}, step_time
# check nan loss # check nan loss
if torch.isnan(loss_dict["loss"]).any(): if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.") raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
# set gradient clipping threshold # set gradient clipping threshold
if "grad_clip" in config and config.grad_clip is not None: if "grad_clip" in config and config.grad_clip is not None:
@ -463,7 +474,6 @@ class Trainer:
else: else:
grad_clip = 0.0 # meaning no gradient clipping grad_clip = 0.0 # meaning no gradient clipping
# TODO: compute grad norm
if grad_clip <= 0: if grad_clip <= 0:
grad_norm = 0 grad_norm = 0
@ -474,15 +484,17 @@ class Trainer:
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss: with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
scaled_loss.backward() scaled_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
grad_clip,
) )
else: else:
# model optimizer step in mixed precision mode # model optimizer step in mixed precision mode
scaler.scale(loss_dict["loss"]).backward() scaler.scale(loss_dict["loss"]).backward()
scaler.unscale_(optimizer)
if grad_clip > 0: if grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
grad_norm = 0
scale_prev = scaler.get_scale() scale_prev = scaler.get_scale()
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
@ -491,13 +503,13 @@ class Trainer:
# main model optimizer step # main model optimizer step
loss_dict["loss"].backward() loss_dict["loss"].backward()
if grad_clip > 0: if grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
optimizer.step() optimizer.step()
step_time = time.time() - step_start_time step_time = time.time() - step_start_time
# setup lr # setup lr
if scheduler is not None and update_lr_scheduler: if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
scheduler.step() scheduler.step()
# detach losses # detach losses
@ -505,7 +517,9 @@ class Trainer:
if optimizer_idx is not None: if optimizer_idx is not None:
loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss") loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
return outputs, loss_dict, step_time, grad_norm else:
loss_dict["grad_norm"] = grad_norm
return outputs, loss_dict, step_time
@staticmethod @staticmethod
def _detach_loss_dict(loss_dict: Dict) -> Dict: def _detach_loss_dict(loss_dict: Dict) -> Dict:
@ -544,11 +558,10 @@ class Trainer:
# conteainers to hold model outputs and losses for each optimizer. # conteainers to hold model outputs and losses for each optimizer.
outputs_per_optimizer = None outputs_per_optimizer = None
log_dict = {}
loss_dict = {} loss_dict = {}
if not isinstance(self.optimizer, list): if not isinstance(self.optimizer, list):
# training with a single optimizer # training with a single optimizer
outputs, loss_dict_new, step_time, grad_norm = self._optimize( outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
) )
loss_dict.update(loss_dict_new) loss_dict.update(loss_dict_new)
@ -560,25 +573,36 @@ class Trainer:
criterion = self.criterion criterion = self.criterion
scaler = self.scaler[idx] if self.use_amp_scaler else None scaler = self.scaler[idx] if self.use_amp_scaler else None
scheduler = self.scheduler[idx] scheduler = self.scheduler[idx]
outputs, loss_dict_new, step_time, grad_norm = self._optimize( outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
) )
# skip the rest if the model returns None # skip the rest if the model returns None
total_step_time += step_time total_step_time += step_time
outputs_per_optimizer[idx] = outputs outputs_per_optimizer[idx] = outputs
# merge loss_dicts from each optimizer
# rename duplicates with the optimizer idx
# if None, model skipped this optimizer # if None, model skipped this optimizer
if loss_dict_new is not None: if loss_dict_new is not None:
loss_dict.update(loss_dict_new) for k, v in loss_dict_new.items():
if k in loss_dict:
loss_dict[f"{k}-{idx}"] = v
else:
loss_dict[k] = v
step_time = total_step_time
outputs = outputs_per_optimizer outputs = outputs_per_optimizer
# update avg stats # update avg runtime stats
keep_avg_update = dict() keep_avg_update = dict()
for key, value in log_dict.items():
keep_avg_update["avg_" + key] = value
keep_avg_update["avg_loader_time"] = loader_time keep_avg_update["avg_loader_time"] = loader_time
keep_avg_update["avg_step_time"] = step_time keep_avg_update["avg_step_time"] = step_time
self.keep_avg_train.update_values(keep_avg_update) self.keep_avg_train.update_values(keep_avg_update)
# update avg loss stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
self.keep_avg_train.update_values(update_eval_values)
# print training progress # print training progress
if self.total_steps_done % self.config.print_step == 0: if self.total_steps_done % self.config.print_step == 0:
# log learning rates # log learning rates
@ -590,33 +614,27 @@ class Trainer:
else: else:
current_lr = self.optimizer.param_groups[0]["lr"] current_lr = self.optimizer.param_groups[0]["lr"]
lrs = {"current_lr": current_lr} lrs = {"current_lr": current_lr}
log_dict.update(lrs)
if grad_norm > 0:
log_dict.update({"grad_norm": grad_norm})
# log run-time stats # log run-time stats
log_dict.update( loss_dict.update(
{ {
"step_time": round(step_time, 4), "step_time": round(step_time, 4),
"loader_time": round(loader_time, 4), "loader_time": round(loader_time, 4),
} }
) )
self.c_logger.print_train_step( self.c_logger.print_train_step(
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
) )
if self.args.rank == 0: if self.args.rank == 0:
# Plot Training Iter Stats # Plot Training Iter Stats
# reduce TB load and don't log every step # reduce TB load and don't log every step
if self.total_steps_done % self.config.tb_plot_step == 0: if self.total_steps_done % self.config.plot_step == 0:
iter_stats = log_dict self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict)
iter_stats.update(loss_dict)
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0: if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
if self.config.checkpoint: if self.config.checkpoint:
# checkpoint the model # checkpoint the model
model_loss = ( target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
)
save_checkpoint( save_checkpoint(
self.config, self.config,
self.model, self.model,
@ -625,8 +643,14 @@ class Trainer:
self.total_steps_done, self.total_steps_done,
self.epochs_done, self.epochs_done,
self.output_path, self.output_path,
model_loss=model_loss, model_loss=target_avg_loss,
) )
if self.total_steps_done % self.config.log_model_step == 0:
# log checkpoint as artifact
aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
# training visualizations # training visualizations
figures, audios = None, None figures, audios = None, None
if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"): if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
@ -634,11 +658,13 @@ class Trainer:
elif hasattr(self.model, "train_log"): elif hasattr(self.model, "train_log"):
figures, audios = self.model.train_log(self.ap, batch, outputs) figures, audios = self.model.train_log(self.ap, batch, outputs)
if figures is not None: if figures is not None:
self.tb_logger.tb_train_figures(self.total_steps_done, figures) self.dashboard_logger.train_figures(self.total_steps_done, figures)
if audios is not None: if audios is not None:
self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate) self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.total_steps_done += 1 self.total_steps_done += 1
self.callbacks.on_train_step_end() self.callbacks.on_train_step_end()
self.dashboard_logger.flush()
return outputs, loss_dict return outputs, loss_dict
def train_epoch(self) -> None: def train_epoch(self) -> None:
@ -663,9 +689,17 @@ class Trainer:
if self.args.rank == 0: if self.args.rank == 0:
epoch_stats = {"epoch_time": epoch_time} epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values) epoch_stats.update(self.keep_avg_train.avg_values)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats) self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats)
if self.config.tb_model_param_stats: if self.config.model_param_stats:
self.tb_logger.tb_model_weights(self.model, self.total_steps_done) self.logger.model_weights(self.model, self.total_steps_done)
# scheduler step after the epoch
if self.scheduler is not None and self.config.scheduler_after_epoch:
if isinstance(self.scheduler, list):
for scheduler in self.scheduler:
if scheduler is not None:
scheduler.step()
else:
self.scheduler.step()
@staticmethod @staticmethod
def _model_eval_step( def _model_eval_step(
@ -701,19 +735,22 @@ class Trainer:
Tuple[Dict, Dict]: Model outputs and losses. Tuple[Dict, Dict]: Model outputs and losses.
""" """
with torch.no_grad(): with torch.no_grad():
outputs_per_optimizer = None outputs = []
loss_dict = {} loss_dict = {}
if not isinstance(self.optimizer, list): if not isinstance(self.optimizer, list):
outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion) outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
else: else:
outputs_per_optimizer = [None] * len(self.optimizer) outputs = [None] * len(self.optimizer)
for idx, _ in enumerate(self.optimizer): for idx, _ in enumerate(self.optimizer):
criterion = self.criterion criterion = self.criterion
outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx) outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
outputs_per_optimizer[idx] = outputs outputs[idx] = outputs_
if loss_dict_new is not None: if loss_dict_new is not None:
loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
loss_dict.update(loss_dict_new) loss_dict.update(loss_dict_new)
outputs = outputs_per_optimizer
loss_dict = self._detach_loss_dict(loss_dict)
# update avg stats # update avg stats
update_eval_values = dict() update_eval_values = dict()
@ -755,28 +792,35 @@ class Trainer:
elif hasattr(self.model, "eval_log"): elif hasattr(self.model, "eval_log"):
figures, audios = self.model.eval_log(self.ap, batch, outputs) figures, audios = self.model.eval_log(self.ap, batch, outputs)
if figures is not None: if figures is not None:
self.tb_logger.tb_eval_figures(self.total_steps_done, figures) self.dashboard_logger.eval_figures(self.total_steps_done, figures)
if audios is not None: if audios is not None:
self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate) self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values) self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
def test_run(self) -> None: def test_run(self) -> None:
"""Run test and log the results. Test run must be defined by the model. """Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard.""" Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"): if hasattr(self.model, "test_run"):
if self.eval_loader is None:
self.eval_loader = self.get_eval_dataloader(
self.ap,
self.data_eval,
verbose=True,
)
if hasattr(self.eval_loader.dataset, "load_test_samples"): if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1) samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None) figures, audios = self.model.test_run(self.ap, samples, None)
else: else:
figures, audios = self.model.test_run(self.ap) figures, audios = self.model.test_run(self.ap)
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, figures) self.dashboard_logger.test_figures(self.total_steps_done, figures)
def _fit(self) -> None: def _fit(self) -> None:
"""🏃 train -> evaluate -> test for the number of epochs.""" """🏃 train -> evaluate -> test for the number of epochs."""
if self.restore_step != 0 or self.args.best_path: if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...") print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"] self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.") print(f" > Starting with loaded last best loss {self.best_loss}.")
self.total_steps_done = self.restore_step self.total_steps_done = self.restore_step
@ -802,10 +846,13 @@ class Trainer:
"""Where the ✨magic✨ happens...""" """Where the ✨magic✨ happens..."""
try: try:
self._fit() self._fit()
self.dashboard_logger.finish()
except KeyboardInterrupt: except KeyboardInterrupt:
self.callbacks.on_keyboard_interrupt() self.callbacks.on_keyboard_interrupt()
# if the output folder is empty remove the run. # if the output folder is empty remove the run.
remove_experiment_folder(self.output_path) remove_experiment_folder(self.output_path)
# finish the wandb run and sync data
self.dashboard_logger.finish()
# stop without error signal # stop without error signal
try: try:
sys.exit(0) sys.exit(0)
@ -816,10 +863,33 @@ class Trainer:
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)
def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
"""Pick the target loss to compare models"""
target_avg_loss = None
# return if target loss defined in the model config
if "target_loss" in self.config and self.config.target_loss:
return keep_avg_target[f"avg_{self.config.target_loss}"]
# take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
if isinstance(self.optimizer, list):
target_avg_loss = 0
for idx in range(len(self.optimizer)):
target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
target_avg_loss /= len(self.optimizer)
else:
target_avg_loss = keep_avg_target["avg_loss"]
return target_avg_loss
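As an illustration of the selection above (a minimal sketch with made-up values, not part of the diff): with two optimizers, the per-optimizer averages tracked by `KeepAverage` are simply averaged again to obtain the comparison value.

# hypothetical values, mirroring the `avg_loss_{idx}` keys produced in train_step
keep_avg = {"avg_loss_0": 1.2, "avg_loss_1": 0.8}  # e.g. generator / discriminator averages
n_optimizers = 2
target_avg_loss = sum(keep_avg[f"avg_loss_{i}"] for i in range(n_optimizers)) / n_optimizers
print(target_avg_loss)  # 1.0 -> compared against best_loss before saving a best model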
def save_best_model(self) -> None: def save_best_model(self) -> None:
"""Save the best model. It only saves if the current target loss is smaller then the previous.""" """Save the best model. It only saves if the current target loss is smaller then the previous."""
# set the target loss to choose the best model
target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
# save the model and update the best_loss
self.best_loss = save_best_model( self.best_loss = save_best_model(
self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"], target_loss_dict,
self.best_loss, self.best_loss,
self.config, self.config,
self.model, self.model,
@ -834,9 +904,16 @@ class Trainer:
@staticmethod @staticmethod
def _setup_logger_config(log_file: str) -> None: def _setup_logger_config(log_file: str) -> None:
logging.basicConfig( handlers = [logging.StreamHandler()]
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
) # Only add a log file if the output location is local due to poor
# support for writing logs to file-like objects.
parsed_url = urlparse(log_file)
if not parsed_url.scheme or parsed_url.scheme == "file":
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
handlers.append(logging.FileHandler(schemeless_path))
logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
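A quick sanity check of the scheme test above (paths are illustrative): plain and `file://` paths keep the `FileHandler`, while object-store URLs fall back to console-only logging.

from urllib.parse import urlparse

for log_file in ["/tmp/run/trainer_0_log.txt", "file:///tmp/run/trainer_0_log.txt", "s3://bucket/run/log.txt"]:
    scheme = urlparse(log_file).scheme
    print(log_file, "->", "file + console" if not scheme or scheme == "file" else "console only")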
@staticmethod @staticmethod
def _is_apex_available() -> bool: def _is_apex_available() -> bool:
@ -920,28 +997,33 @@ class Trainer:
return criterion return criterion
def init_arguments(): def getarguments():
train_config = TrainingArgs() train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="") parser = train_config.init_argparse(arg_prefix="")
return parser return parser
def get_last_checkpoint(path): def get_last_checkpoint(path: str) -> Tuple[str, str]:
"""Get latest checkpoint or/and best model in path. """Get latest checkpoint or/and best model in path.
It is based on globbing for `*.pth.tar` and the RegEx It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`. `(checkpoint|best_model)_([0-9]+)`.
Args: Args:
path (list): Path to files to be compared. path: Path to files to be compared.
Raises: Raises:
ValueError: If no checkpoint or best_model files are found. ValueError: If no checkpoint or best_model files are found.
Returns: Returns:
last_checkpoint (str): Last checkpoint filename. Path to the last checkpoint
Path to best checkpoint
""" """
file_names = glob.glob(os.path.join(path, "*.pth.tar")) fs = fsspec.get_mapper(path).fs
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
scheme = urlparse(path).scheme
if scheme: # scheme is not preserved in fs.glob, add it back
file_names = [scheme + "://" + file_name for file_name in file_names]
last_models = {} last_models = {}
last_model_nums = {} last_model_nums = {}
for key in ["checkpoint", "best_model"]: for key in ["checkpoint", "best_model"]:
@ -963,7 +1045,7 @@ def get_last_checkpoint(path):
key_file_names = [fn for fn in file_names if key in fn] key_file_names = [fn for fn in file_names if key in fn]
if last_model is None and len(key_file_names) > 0: if last_model is None and len(key_file_names) > 0:
last_model = max(key_file_names, key=os.path.getctime) last_model = max(key_file_names, key=os.path.getctime)
last_model_num = torch.load(last_model)["step"] last_model_num = load_fsspec(last_model)["step"]
if last_model is not None: if last_model is not None:
last_models[key] = last_model last_models[key] = last_model
@ -997,8 +1079,8 @@ def process_args(args, config=None):
audio_path (str): Path to save generated test audios. audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
logging to the console. logging to the console.
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
the TensorBoard logging. dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
TODO: TODO:
- Interactive config definition. - Interactive config definition.
@ -1012,6 +1094,7 @@ def process_args(args, config=None):
args.restore_path, best_model = get_last_checkpoint(args.continue_path) args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path: if not args.best_path:
args.best_path = best_model args.best_path = best_model
# init config if not already defined # init config if not already defined
if config is None: if config is None:
if args.config_path: if args.config_path:
@ -1030,12 +1113,12 @@ def process_args(args, config=None):
print(" > Mixed precision mode is ON") print(" > Mixed precision mode is ON")
experiment_path = args.continue_path experiment_path = args.continue_path
if not experiment_path: if not experiment_path:
experiment_path = create_experiment_folder(config.output_path, config.run_name) experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
audio_path = os.path.join(experiment_path, "test_audios") audio_path = os.path.join(experiment_path, "test_audios")
config.output_log_path = experiment_path
# setup rank 0 process in distributed training # setup rank 0 process in distributed training
tb_logger = None dashboard_logger = None
if args.rank == 0: if args.rank == 0:
os.makedirs(audio_path, exist_ok=True)
new_fields = {} new_fields = {}
if args.restore_path: if args.restore_path:
new_fields["restore_path"] = args.restore_path new_fields["restore_path"] = args.restore_path
@ -1043,17 +1126,20 @@ def process_args(args, config=None):
# if model characters are not set in the config file # if model characters are not set in the config file
# save the default set to the config file for future # save the default set to the config file for future
# compatibility. # compatibility.
if config.has("characters_config"): if config.has("characters") and config.characters is None:
used_characters = parse_symbols() used_characters = parse_symbols()
new_fields["characters"] = used_characters new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields) copy_model_files(config, experiment_path, new_fields)
os.chmod(audio_path, 0o775)
os.chmod(experiment_path, 0o775) dashboard_logger = init_logger(config)
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, tb_logger return config, experiment_path, audio_path, c_logger, dashboard_logger
def init_arguments():
train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
def init_training(argv: Union[List, Coqpit], config: Coqpit = None): def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
@ -1063,5 +1149,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
else: else:
parser = init_arguments() parser = init_arguments()
args = parser.parse_known_args() args = parser.parse_known_args()
config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config) config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger

View File

@ -13,12 +13,16 @@ class GSTConfig(Coqpit):
Args: Args:
gst_style_input_wav (str): gst_style_input_wav (str):
Path to the wav file used to define the style of the output speech at inference. Defaults to None. Path to the wav file used to define the style of the output speech at inference. Defaults to None.
gst_style_input_weights (dict): gst_style_input_weights (dict):
Defines the weights for each style token used at inference. Defaults to None. Defines the weights for each style token used at inference. Defaults to None.
gst_embedding_dim (int): gst_embedding_dim (int):
Defines the size of the GST embedding vector dimensions. Defaults to 256. Defines the size of the GST embedding vector dimensions. Defaults to 256.
gst_num_heads (int): gst_num_heads (int):
Number of attention heads used by the multi-head attention. Defaults to 4. Number of attention heads used by the multi-head attention. Defaults to 4.
gst_num_style_tokens (int): gst_num_style_tokens (int):
Number of style token vectors. Defaults to 10. Number of style token vectors. Defaults to 10.
""" """
@ -51,17 +55,23 @@ class CharactersConfig(Coqpit):
Args: Args:
pad (str): pad (str):
characters in place of empty padding. Defaults to None. characters in place of empty padding. Defaults to None.
eos (str): eos (str):
characters showing the end of a sentence. Defaults to None. characters showing the end of a sentence. Defaults to None.
bos (str): bos (str):
characters showing the beginning of a sentence. Defaults to None. characters showing the beginning of a sentence. Defaults to None.
characters (str): characters (str):
character set used by the model. Characters not in this list are ignored when converting input text to character set used by the model. Characters not in this list are ignored when converting input text to
a list of sequence IDs. Defaults to None. a list of sequence IDs. Defaults to None.
punctuations (str): punctuations (str):
characters considered as punctuation when parsing the input sentence. Defaults to None. characters considered as punctuation when parsing the input sentence. Defaults to None.
phonemes (str): phonemes (str):
characters considered as phonemes when parsing. Defaults to None. characters considered as phonemes when parsing. Defaults to None.
unique (bool): unique (bool):
remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old
models trained with character lists with duplicates. models trained with character lists with duplicates.
@ -95,54 +105,78 @@ class BaseTTSConfig(BaseTrainingConfig):
Args: Args:
audio (BaseAudioConfig): audio (BaseAudioConfig):
Audio processor config object instance. Audio processor config object instance.
use_phonemes (bool): use_phonemes (bool):
enable / disable phoneme use. enable / disable phoneme use.
use_espeak_phonemes (bool): use_espeak_phonemes (bool):
enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`). enable / disable eSpeak-compatible phonemes (only if use_phonemes = `True`).
compute_input_seq_cache (bool): compute_input_seq_cache (bool):
enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of
the training, it allows faster data loading and precise limitation with `max_seq_len` and the training, it allows faster data loading and precise limitation with `max_seq_len` and
`min_seq_len`. `min_seq_len`.
text_cleaner (str): text_cleaner (str):
Name of the text cleaner used for cleaning and formatting transcripts. Name of the text cleaner used for cleaning and formatting transcripts.
enable_eos_bos_chars (bool): enable_eos_bos_chars (bool):
enable / disable the use of eos and bos characters. enable / disable the use of eos and bos characters.
test_senteces_file (str): test_senteces_file (str):
Path to a txt file that has sentences used at test time. The file must have a sentence per line. Path to a txt file that has sentences used at test time. The file must have a sentence per line.
phoneme_cache_path (str): phoneme_cache_path (str):
Path to the output folder caching the computed phonemes for each sample. Path to the output folder caching the computed phonemes for each sample.
characters (CharactersConfig): characters (CharactersConfig):
Instance of a CharactersConfig class. Instance of a CharactersConfig class.
batch_group_size (int): batch_group_size (int):
Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
prevent using the same batches for each epoch. prevent using the same batches for each epoch.
loss_masking (bool): loss_masking (bool):
enable / disable masking loss values against padded segments of samples in a batch. enable / disable masking loss values against padded segments of samples in a batch.
min_seq_len (int): min_seq_len (int):
Minimum input sequence length to be used at training. Minimum input sequence length to be used at training.
max_seq_len (int): max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage. Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
compute_f0 (int): compute_f0 (int):
(Not in use yet). (Not in use yet).
compute_linear_spec (bool):
If True data loader computes and returns linear spectrograms alongside the other data.
use_noise_augment (bool): use_noise_augment (bool):
Augment the input audio with random noise. Augment the input audio with random noise.
add_blank (bool): add_blank (bool):
Add blank characters between every two characters. It improves performance for some models at the expense Add blank characters between every two characters. It improves performance for some models at the expense
of slower run-time due to the longer input sequence. of slower run-time due to the longer input sequence.
datasets (List[BaseDatasetConfig]): datasets (List[BaseDatasetConfig]):
List of datasets used for training. If multiple datasets are provided, they are merged and used together List of datasets used for training. If multiple datasets are provided, they are merged and used together
for training. for training.
optimizer (str): optimizer (str):
Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
Defaults to ``. Defaults to ``.
optimizer_params (dict): optimizer_params (dict):
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
lr_scheduler (str): lr_scheduler (str):
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
`TTS.utils.training`. Defaults to ``. `TTS.utils.training`. Defaults to ``.
lr_scheduler_params (dict): lr_scheduler_params (dict):
Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
test_sentences (List[str]): test_sentences (List[str]):
List of sentences to be used at testing. Defaults to '[]' List of sentences to be used at testing. Defaults to '[]'
""" """
@ -166,6 +200,7 @@ class BaseTTSConfig(BaseTrainingConfig):
min_seq_len: int = 1 min_seq_len: int = 1
max_seq_len: int = float("inf") max_seq_len: int = float("inf")
compute_f0: bool = False compute_f0: bool = False
compute_linear_spec: bool = False
use_noise_augment: bool = False use_noise_augment: bool = False
add_blank: bool = False add_blank: bool = False
# dataset # dataset

View File

@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs
@dataclass
class VitsConfig(BaseTTSConfig):
"""Defines parameters for VITS End2End TTS model.
Args:
model (str):
Model name. Do not change unless you know what you are doing.
model_args (VitsArgs):
Model architecture arguments. Defaults to `VitsArgs()`.
grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`.
lr_gen (float):
Initial learning rate for the generator. Defaults to 0.0002.
lr_disc (float):
Initial learning rate for the discriminator. Defaults to 0.0002.
lr_scheduler_gen (str):
Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.
lr_scheduler_gen_params (dict):
Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
lr_scheduler_disc (str):
Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.
lr_scheduler_disc_params (dict):
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
scheduler_after_epoch (bool):
If true, step the schedulers after each epoch; otherwise step them after each iteration. Defaults to `True`.
optimizer (str):
Name of the optimizer to use with both the generator and the discriminator networks. One of the
`torch.optim.*`. Defaults to `AdamW`.
kl_loss_alpha (float):
Loss weight for KL loss. Defaults to 1.0.
disc_loss_alpha (float):
Loss weight for the discriminator loss. Defaults to 1.0.
gen_loss_alpha (float):
Loss weight for the generator loss. Defaults to 1.0.
feat_loss_alpha (float):
Loss weight for the feature matching loss. Defaults to 1.0.
mel_loss_alpha (float):
Loss weight for the mel loss. Defaults to 45.0.
return_wav (bool):
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
min_seq_len (int):
Minimum text length to be considered for training. Defaults to `13`.
max_seq_len (int):
Maximum text length to be considered for training. Defaults to `500`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
add_blank (bool):
If true, a blank token is added in between every character. Defaults to `True`.
test_sentences (List[str]):
List of sentences to be used for testing.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
Example:
>>> from TTS.tts.configs import VitsConfig
>>> config = VitsConfig()
"""
model: str = "vits"
# model specific params
model_args: VitsArgs = field(default_factory=VitsArgs)
# optimizer
grad_clip: List[float] = field(default_factory=lambda: [5, 5])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
lr_scheduler_disc: str = "ExponentialLR"
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
scheduler_after_epoch: bool = True
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})
# loss params
kl_loss_alpha: float = 1.0
disc_loss_alpha: float = 1.0
gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
# data loader params
return_wav: bool = True
compute_linear_spec: bool = True
# overrides
min_seq_len: int = 13
max_seq_len: int = 500
r: int = 1 # DO NOT CHANGE
add_blank: bool = True
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
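As a usage sketch (not part of the diff; `batch_size` is assumed to come from the inherited training config), the class behaves like any other Coqpit dataclass and can be overridden field by field:

from TTS.tts.configs import VitsConfig  # same import path as the docstring example above

config = VitsConfig()
config.batch_size = 32      # inherited training field (assumed name)
config.lr_gen = 1e-4        # generator learning rate
config.lr_disc = 1e-4       # discriminator learning rate
print(config.model, config.add_blank, config.r)  # vits True 1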

View File

@ -23,7 +23,9 @@ class TTSDataset(Dataset):
ap: AudioProcessor, ap: AudioProcessor,
meta_data: List[List], meta_data: List[List],
characters: Dict = None, characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False, add_blank: bool = False,
return_wav: bool = False,
batch_group_size: int = 0, batch_group_size: int = 0,
min_seq_len: int = 0, min_seq_len: int = 0,
max_seq_len: int = float("inf"), max_seq_len: int = float("inf"),
@ -54,9 +56,14 @@ class TTSDataset(Dataset):
characters (dict): `dict` of custom text characters used for converting texts to sequences. characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using their own
set of symbols need to pass it here. Defaults to `None`.
add_blank (bool): Add a special `blank` character after every other character. It helps some add_blank (bool): Add a special `blank` character after every other character. It helps some
models achieve better results. Defaults to false. models achieve better results. Defaults to false.
return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting batch_group_size (int): Range of batch randomization after sorting
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
batch. Set 0 to disable. Defaults to 0. batch. Set 0 to disable. Defaults to 0.
@ -95,10 +102,12 @@ class TTSDataset(Dataset):
self.sample_rate = ap.sample_rate self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.min_seq_len = min_seq_len self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.ap = ap self.ap = ap
self.characters = characters self.characters = characters
self.custom_symbols = custom_symbols
self.add_blank = add_blank self.add_blank = add_blank
self.use_phonemes = use_phonemes self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path self.phoneme_cache_path = phoneme_cache_path
@ -109,6 +118,7 @@ class TTSDataset(Dataset):
self.use_noise_augment = use_noise_augment self.use_noise_augment = use_noise_augment
self.verbose = verbose self.verbose = verbose
self.input_seq_computed = False self.input_seq_computed = False
self.rescue_item_idx = 1
if use_phonemes and not os.path.isdir(phoneme_cache_path): if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True) os.makedirs(phoneme_cache_path, exist_ok=True)
if self.verbose: if self.verbose:
@ -128,13 +138,21 @@ class TTSDataset(Dataset):
return data return data
@staticmethod @staticmethod
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank): def _generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
):
"""generate a phoneme sequence from text. """generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the eos chars here. Instead we add those dynamically later; based on the
config option.""" config option."""
phonemes = phoneme_to_sequence( phonemes = phoneme_to_sequence(
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank text,
[cleaners],
language=language,
enable_eos_bos=False,
custom_symbols=custom_symbols,
tp=characters,
add_blank=add_blank,
) )
phonemes = np.asarray(phonemes, dtype=np.int32) phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes) np.save(cache_path, phonemes)
@ -142,7 +160,7 @@ class TTSDataset(Dataset):
@staticmethod @staticmethod
def _load_or_generate_phoneme_sequence( def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
): ):
file_name = os.path.splitext(os.path.basename(wav_file))[0] file_name = os.path.splitext(os.path.basename(wav_file))[0]
@ -153,12 +171,12 @@ class TTSDataset(Dataset):
phonemes = np.load(cache_path) phonemes = np.load(cache_path)
except FileNotFoundError: except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence( phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, characters, add_blank text, cache_path, cleaners, language, custom_symbols, characters, add_blank
) )
except (ValueError, IOError): except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file)) print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence( phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, characters, add_blank text, cache_path, cleaners, language, custom_symbols, characters, add_blank
) )
if enable_eos_bos: if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=characters) phonemes = pad_with_eos_bos(phonemes, tp=characters)
@ -173,6 +191,7 @@ class TTSDataset(Dataset):
else: else:
text, wav_file, speaker_name = item text, wav_file, speaker_name = item
attn = None attn = None
raw_text = text
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
@ -189,13 +208,19 @@ class TTSDataset(Dataset):
self.enable_eos_bos, self.enable_eos_bos,
self.cleaners, self.cleaners,
self.phoneme_language, self.phoneme_language,
self.custom_symbols,
self.characters, self.characters,
self.add_blank, self.add_blank,
) )
else: else:
text = np.asarray( text = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank), text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32, dtype=np.int32,
) )
@ -209,9 +234,10 @@ class TTSDataset(Dataset):
# return a different sample if the phonemized # return a different sample if the phonemized
# text is longer than the threshold # text is longer than the threshold
# TODO: find a better fix # TODO: find a better fix
return self.load_data(100) return self.load_data(self.rescue_item_idx)
sample = { sample = {
"raw_text": raw_text,
"text": text, "text": text,
"wav": wav, "wav": wav,
"attn": attn, "attn": attn,
@ -238,7 +264,13 @@ class TTSDataset(Dataset):
for idx, item in enumerate(tqdm.tqdm(self.items)): for idx, item in enumerate(tqdm.tqdm(self.items)):
text, *_ = item text, *_ = item
sequence = np.asarray( sequence = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank), text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32, dtype=np.int32,
) )
self.items[idx][0] = sequence self.items[idx][0] = sequence
@ -249,6 +281,7 @@ class TTSDataset(Dataset):
self.enable_eos_bos, self.enable_eos_bos,
self.cleaners, self.cleaners,
self.phoneme_language, self.phoneme_language,
self.custom_symbols,
self.characters, self.characters,
self.add_blank, self.add_blank,
] ]
@ -329,6 +362,7 @@ class TTSDataset(Dataset):
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing] wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing] item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
text = [batch[idx]["text"] for idx in ids_sorted_decreasing] text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
# get pre-computed d-vectors # get pre-computed d-vectors
@ -347,6 +381,14 @@ class TTSDataset(Dataset):
mel_lengths = [m.shape[1] for m in mel] mel_lengths = [m.shape[1] for m in mel]
# lengths adjusted by the reduction factor
mel_lengths_adjusted = [
m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
if m.shape[1] % self.outputs_per_step
else m.shape[1]
for m in mel
]
# compute 'stop token' targets # compute 'stop token' targets
stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
@ -385,6 +427,20 @@ class TTSDataset(Dataset):
else: else:
linear = None linear = None
# format waveforms
wav_padded = None
if self.return_wav:
wav_lengths = [w.shape[0] for w in wav]
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
wav_lengths = torch.LongTensor(wav_lengths)
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
for i, w in enumerate(wav):
mel_length = mel_lengths_adjusted[i]
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.ap.hop_length]
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
# collate attention alignments # collate attention alignments
if batch[0]["attn"] is not None: if batch[0]["attn"] is not None:
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing] attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
@ -397,6 +453,7 @@ class TTSDataset(Dataset):
attns = torch.FloatTensor(attns).unsqueeze(1) attns = torch.FloatTensor(attns).unsqueeze(1)
else: else:
attns = None attns = None
# TODO: return dictionary
return ( return (
text, text,
text_lenghts, text_lenghts,
@ -409,6 +466,8 @@ class TTSDataset(Dataset):
d_vectors, d_vectors,
speaker_ids, speaker_ids,
attns, attns,
wav_padded,
raw_text,
) )
raise TypeError( raise TypeError(

View File

@ -28,6 +28,31 @@ class LayerNorm(nn.Module):
return x return x
class LayerNorm2(nn.Module):
"""Layer norm for the 2nd dimension of the input using torch primitive.
Args:
channels (int): number of channels (2nd dimension) of the input.
eps (float): small constant to prevent division by zero.
Shapes:
- input: (B, C, T)
- output: (B, C, T)
"""
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
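A small shape check for the module above (tensor sizes are illustrative): normalization runs over the channel dimension and the `[B, C, T]` shape is preserved.

import torch

layer = LayerNorm2(channels=192)
x = torch.randn(4, 192, 50)   # [B, C, T]
y = layer(x)
print(y.shape)                # torch.Size([4, 192, 50])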
class TemporalBatchNorm1d(nn.BatchNorm1d): class TemporalBatchNorm1d(nn.BatchNorm1d):
"""Normalize each channel separately over time and batch.""" """Normalize each channel separately over time and batch."""

View File

@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
dropout_p (float): Dropout rate used after each conv layer. dropout_p (float): Dropout rate used after each conv layer.
""" """
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p): def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
super().__init__() super().__init__()
# class arguments # class arguments
self.in_channels = in_channels self.in_channels = in_channels
@ -33,13 +33,18 @@ class DurationPredictor(nn.Module):
self.norm_2 = LayerNorm(hidden_channels) self.norm_2 = LayerNorm(hidden_channels)
# output layer # output layer
self.proj = nn.Conv1d(hidden_channels, 1, 1) self.proj = nn.Conv1d(hidden_channels, 1, 1)
if cond_channels is not None and cond_channels != 0:
self.cond = nn.Conv1d(cond_channels, in_channels, 1)
def forward(self, x, x_mask): def forward(self, x, x_mask, g=None):
""" """
Shapes: Shapes:
- x: :math:`[B, C, T]` - x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]` - x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
""" """
if g is not None:
x = x + self.cond(g)
x = self.conv_1(x * x_mask) x = self.conv_1(x * x_mask)
x = torch.relu(x) x = torch.relu(x)
x = self.norm_1(x) x = self.norm_1(x)

View File

@ -16,7 +16,7 @@ class ResidualConv1dLayerNormBlock(nn.Module):
:: ::
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|---------------> conv1d_1x1 -----------------------| |---------------> conv1d_1x1 ------------------|
Args: Args:
in_channels (int): number of input tensor channels. in_channels (int): number of input tensor channels.

View File

@ -4,7 +4,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from TTS.tts.layers.glow_tts.glow import LayerNorm from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
class RelativePositionMultiHeadAttention(nn.Module): class RelativePositionMultiHeadAttention(nn.Module):
@ -271,7 +271,7 @@ class FeedForwardNetwork(nn.Module):
dropout_p (float, optional): dropout rate. Defaults to 0. dropout_p (float, optional): dropout rate. Defaults to 0.
""" """
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0): def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -280,17 +280,46 @@ class FeedForwardNetwork(nn.Module):
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dropout_p = dropout_p self.dropout_p = dropout_p
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) if causal:
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2) self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size)
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size)
self.dropout = nn.Dropout(dropout_p) self.dropout = nn.Dropout(dropout_p)
def forward(self, x, x_mask): def forward(self, x, x_mask):
x = self.conv_1(x * x_mask) x = self.conv_1(self.padding(x * x_mask))
x = torch.relu(x) x = torch.relu(x)
x = self.dropout(x) x = self.dropout(x)
x = self.conv_2(x * x_mask) x = self.conv_2(self.padding(x * x_mask))
return x * x_mask return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
return x
@staticmethod
def _pad_shape(padding):
l = padding[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
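To make the two padding modes above concrete (assuming `kernel_size=3`): causal padding puts all `kernel_size - 1` zeros on the left so a frame never sees the future, while same padding splits them around the input; both keep the time length unchanged after the un-padded Conv1d.

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 10)         # [B, C, T]
causal = F.pad(x, [2, 0])         # pad_l=2, pad_r=0 for kernel_size=3
same = F.pad(x, [1, 1])           # pad_l=1, pad_r=1 for kernel_size=3
print(causal.shape, same.shape)   # both [1, 8, 12]; a conv with k=3 brings T back to 10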
class RelativePositionTransformer(nn.Module): class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Potional Encoding. """Transformer with Relative Potional Encoding.
@ -310,20 +339,23 @@ class RelativePositionTransformer(nn.Module):
If default, relative encoding is disabled and it is a regular transformer. If default, relative encoding is disabled and it is a regular transformer.
Defaults to None. Defaults to None.
input_length (int, optional): input length to limit position encoding. Defaults to None. input_length (int, optional): input length to limit position encoding. Defaults to None.
layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm
primitive. Use type "2", type "1: is for backward compat. Defaults to "1".
""" """
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
out_channels, out_channels: int,
hidden_channels, hidden_channels: int,
hidden_channels_ffn, hidden_channels_ffn: int,
num_heads, num_heads: int,
num_layers, num_layers: int,
kernel_size=1, kernel_size=1,
dropout_p=0.0, dropout_p=0.0,
rel_attn_window_size=None, rel_attn_window_size: int = None,
input_length=None, input_length: int = None,
layer_norm_type: str = "1",
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -351,7 +383,12 @@ class RelativePositionTransformer(nn.Module):
input_length=input_length, input_length=input_length,
) )
) )
self.norm_layers_1.append(LayerNorm(hidden_channels)) if layer_norm_type == "1":
self.norm_layers_1.append(LayerNorm(hidden_channels))
elif layer_norm_type == "2":
self.norm_layers_1.append(LayerNorm2(hidden_channels))
else:
raise ValueError(" [!] Unknown layer norm type")
if hidden_channels != out_channels and (idx + 1) == self.num_layers: if hidden_channels != out_channels and (idx + 1) == self.num_layers:
self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
@ -366,7 +403,12 @@ class RelativePositionTransformer(nn.Module):
) )
) )
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) if layer_norm_type == "1":
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
elif layer_norm_type == "2":
self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
else:
raise ValueError(" [!] Unknown layer norm type")
def forward(self, x, x_mask): def forward(self, x, x_mask):
""" """

View File

@ -2,11 +2,13 @@ import math
import numpy as np import numpy as np
import torch import torch
from coqpit import Coqpit
from torch import nn from torch import nn
from torch.nn import functional from torch.nn import functional
from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.ssim import ssim from TTS.tts.utils.ssim import ssim
from TTS.utils.audio import TorchSTFT
# pylint: disable=abstract-method # pylint: disable=abstract-method
@ -514,3 +516,142 @@ class AlignTTSLoss(nn.Module):
+ self.mdn_alpha * mdn_loss + self.mdn_alpha * mdn_loss
) )
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
class VitsGeneratorLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
self.kl_loss_alpha = c.kl_loss_alpha
self.gen_loss_alpha = c.gen_loss_alpha
self.feat_loss_alpha = c.feat_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
c.audio.hop_length,
c.audio.win_length,
sample_rate=c.audio.sample_rate,
mel_fmin=c.audio.mel_fmin,
mel_fmax=c.audio.mel_fmax,
n_mels=c.audio.num_mels,
use_mel=True,
do_amp_to_db=True,
)
@staticmethod
def feature_loss(feats_real, feats_generated):
loss = 0
for dr, dg in zip(feats_real, feats_generated):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
@staticmethod
def generator_loss(scores_fake):
loss = 0
gen_losses = []
for dg in scores_fake:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
@staticmethod
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
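For reference, my reading of the term above (not text from the diff): per element it is a Monte-Carlo estimate of the KL divergence between the posterior q = N(mu_q, sigma_q^2) (with logs_q = log sigma_q) and the prior p = N(mu_p, sigma_p^2), evaluated at the flow-mapped sample z_p and then masked and averaged over valid frames:

\mathrm{KL}(q \,\|\, p)
  = \log\sigma_p - \log\sigma_q - \tfrac{1}{2}
    + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\sigma_p^2}
  \;\approx\;
    \log\sigma_p - \log\sigma_q - \tfrac{1}{2}
    + \frac{(z_p - \mu_p)^2}{2\sigma_p^2},
  \qquad z_p \sim q .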
def forward(
self,
waveform,
waveform_hat,
z_p,
logs_q,
m_p,
logs_p,
z_len,
scores_disc_fake,
feats_disc_fake,
feats_disc_real,
):
"""
Shapes:
- waveform: :math:`[B, 1, T]`
- waveform_hat: :math:`[B, 1, T]`
- z_p: :math:`[B, C, T]`
- logs_q: :math:`[B, C, T]`
- m_p: :math:`[B, C, T]`
- logs_p: :math:`[B, C, T]`
- z_len: :math:`[B]`
- scores_disc_fake[i]: :math:`[B, C]`
- feats_disc_fake[i][j]: :math:`[B, C, T', P]`
- feats_disc_real[i][j]: :math:`[B, C, T', P]`
"""
loss = 0.0
return_dict = {}
z_mask = sequence_mask(z_len).float()
# compute mel spectrograms from the waveforms
mel = self.stft(waveform)
mel_hat = self.stft(waveform_hat)
# compute losses
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl
return_dict["loss_feat"] = loss_feat
return_dict["loss_mel"] = loss_mel
return_dict["loss"] = loss
return return_dict
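In compact form (my summary of the sum above), with the weights taken from the VitsConfig defaults listed earlier (kl = 1.0, feat = 1.0, mel = 45.0, gen = 1.0):

L_G = \alpha_{kl}\,L_{kl} + \alpha_{feat}\,L_{feat} + \alpha_{mel}\,L_{mel} + \alpha_{gen}\,L_{gen}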
class VitsDiscriminatorLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
self.disc_loss_alpha = c.disc_loss_alpha
@staticmethod
def discriminator_loss(scores_real, scores_fake):
loss = 0
real_losses = []
fake_losses = []
for dr, dg in zip(scores_real, scores_fake):
dr = dr.float()
dg = dg.float()
real_loss = torch.mean((1 - dr) ** 2)
fake_loss = torch.mean(dg ** 2)
loss += real_loss + fake_loss
real_losses.append(real_loss.item())
fake_losses.append(fake_loss.item())
return loss, real_losses, fake_losses
def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + loss_disc
return_dict["loss_disc"] = loss_disc
return_dict["loss"] = loss
return return_dict
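Both loss modules implement the least-squares GAN objective (my restatement); summing over the scale and period sub-discriminators D_k, with x the real and \hat{x} the generated waveform:

L_D = \sum_k \Big( \mathbb{E}\big[(1 - D_k(x))^2\big] + \mathbb{E}\big[D_k(\hat{x})^2\big] \Big),
\qquad
L_G^{adv} = \sum_k \mathbb{E}\big[(1 - D_k(\hat{x}))^2\big] .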

View File

@ -8,10 +8,10 @@ class GST(nn.Module):
See https://arxiv.org/pdf/1803.09017""" See https://arxiv.org/pdf/1803.09017"""
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None): def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None):
super().__init__() super().__init__()
self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim) self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim) self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim)
def forward(self, inputs, speaker_embedding=None): def forward(self, inputs, speaker_embedding=None):
enc_out = self.encoder(inputs) enc_out = self.encoder(inputs)
@ -83,19 +83,19 @@ class ReferenceEncoder(nn.Module):
class StyleTokenLayer(nn.Module): class StyleTokenLayer(nn.Module):
"""NN Module attending to style tokens based on prosody encodings.""" """NN Module attending to style tokens based on prosody encodings."""
def __init__(self, num_heads, num_style_tokens, embedding_dim, d_vector_dim=None): def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
super().__init__() super().__init__()
self.query_dim = embedding_dim // 2 self.query_dim = gst_embedding_dim // 2
if d_vector_dim: if d_vector_dim:
self.query_dim += d_vector_dim self.query_dim += d_vector_dim
self.key_dim = embedding_dim // num_heads self.key_dim = gst_embedding_dim // num_heads
self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim)) self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
nn.init.normal_(self.style_tokens, mean=0, std=0.5) nn.init.normal_(self.style_tokens, mean=0, std=0.5)
self.attention = MultiHeadAttention( self.attention = MultiHeadAttention(
query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads
) )
def forward(self, inputs): def forward(self, inputs):

View File

@ -0,0 +1,77 @@
import torch
from torch import nn
from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
class DiscriminatorS(torch.nn.Module):
"""HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.
Args:
use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
Tensor: discriminator scores.
List[Tensor]: list of features from the convolutional layers.
"""
feat = []
for l in self.convs:
x = l(x)
x = torch.nn.functional.leaky_relu(x, 0.1)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat
class VitsDiscriminator(nn.Module):
"""VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator.
::
waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
|--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^
Args:
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores, feats = self.mpd(x)
score_sd, feats_sd = self.sd(x)
return scores + [score_sd], feats + [feats_sd]
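A brief usage sketch (shapes assumed; the number of entries depends on the periods configured in MultiPeriodDiscriminator):

import torch

disc = VitsDiscriminator()
wav = torch.randn(2, 1, 8192)   # [B, 1, T] waveform
scores, feats = disc(wav)
print(len(scores), len(feats))  # one entry per period discriminator plus one for the scale discriminator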

View File

@ -0,0 +1,271 @@
import math
import torch
from torch import nn
from TTS.tts.layers.glow_tts.glow import WN
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.data import sequence_mask
LRELU_SLOPE = 0.1
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class TextEncoder(nn.Module):
def __init__(
self,
n_vocab: int,
out_channels: int,
hidden_channels: int,
hidden_channels_ffn: int,
num_heads: int,
num_layers: int,
kernel_size: int,
dropout_p: float,
):
"""Text Encoder for VITS model.
Args:
n_vocab (int): Number of characters for the embedding layer.
out_channels (int): Number of channels for the output.
hidden_channels (int): Number of channels for the hidden layers.
hidden_channels_ffn (int): Number of channels for the convolutional layers.
num_heads (int): Number of attention heads for the Transformer layers.
num_layers (int): Number of Transformer layers.
kernel_size (int): Kernel size for the FFN layers in Transformer network.
dropout_p (float): Dropout rate for the Transformer layers.
"""
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
hidden_channels_ffn=hidden_channels_ffn,
num_heads=num_heads,
num_layers=num_layers,
kernel_size=kernel_size,
dropout_p=dropout_p,
layer_norm_type="2",
rel_attn_window_size=4,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
"""
Shapes:
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask
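
An illustrative shape check for the text encoder (not part of this patch); the arguments mirror the defaults in `VitsArgs` further down:

import torch

from TTS.tts.layers.vits.networks import TextEncoder

enc = TextEncoder(n_vocab=100, out_channels=192, hidden_channels=192, hidden_channels_ffn=768,
                  num_heads=2, num_layers=6, kernel_size=3, dropout_p=0.1)
tokens = torch.randint(0, 100, (2, 50))  # [B, T] token ids
lengths = torch.tensor([50, 30])
x, m, logs, x_mask = enc(tokens, lengths)
# x, m, logs: [2, 192, 50] hidden states and prior stats; x_mask: [2, 1, 50]
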
class ResidualCouplingBlock(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
dropout_p=0,
cond_channels=0,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.half_channels = channels // 2
self.mean_only = mean_only
# input layer
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
# coupling layers
self.enc = WN(
hidden_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
dropout_p=dropout_p,
c_in_channels=cond_channels,
)
# output layer
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
log_scale = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(log_scale) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(log_scale, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-log_scale) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ResidualCouplingBlocks(nn.Module):
def __init__(
self,
channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
num_layers: int,
num_flows=4,
cond_channels=0,
):
"""Redisual Coupling blocks for VITS flow layers.
Args:
channels (int): Number of input and output tensor channels.
hidden_channels (int): Number of hidden network channels.
kernel_size (int): Kernel size of the WaveNet layers.
dilation_rate (int): Dilation rate of the WaveNet layers.
num_layers (int): Number of the WaveNet layers.
num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
"""
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.num_flows = num_flows
self.cond_channels = cond_channels
self.flows = nn.ModuleList()
for _ in range(num_flows):
self.flows.append(
ResidualCouplingBlock(
channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
cond_channels=cond_channels,
mean_only=True,
)
)
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
x = torch.flip(x, [1])
else:
for flow in reversed(self.flows):
x = torch.flip(x, [1])
x = flow(x, x_mask, g=g, reverse=reverse)
return x
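
Since each coupling block only shifts one half of the channels, the whole stack is exactly invertible; a small round-trip sketch (illustrative only, not part of this patch):

import torch

from TTS.tts.layers.vits.networks import ResidualCouplingBlocks

flow = ResidualCouplingBlocks(channels=192, hidden_channels=192, kernel_size=5, dilation_rate=1, num_layers=4)
z = torch.randn(2, 192, 40)
mask = torch.ones(2, 1, 40)
z_fwd = flow(z, mask)                    # forward pass through all flows
z_rec = flow(z_fwd, mask, reverse=True)  # inverse pass recovers the input
print(torch.allclose(z, z_rec, atol=1e-4))  # expected: True
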
class PosteriorEncoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
num_layers: int,
cond_channels=0,
):
"""Posterior Encoder of VITS model.
::
x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of output tensor channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size of the WaveNet convolution layers.
dilation_rate (int): Dilation rate of the WaveNet layers.
num_layers (int): Number of the WaveNet layers.
cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.cond_channels = cond_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_lengths: :math:`[B, 1]`
- g: :math:`[B, C, 1]`
"""
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
return z, mean, log_scale, x_mask
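
A shape-oriented sketch of the posterior encoder (illustrative only, not part of this patch); `in_channels=513` matches the linear-spectrogram size that `VitsArgs.out_channels` implies:

import torch

from TTS.tts.layers.vits.networks import PosteriorEncoder

pe = PosteriorEncoder(in_channels=513, out_channels=192, hidden_channels=192,
                      kernel_size=5, dilation_rate=1, num_layers=16)
spec = torch.randn(2, 513, 120)         # [B, C_spec, T] linear spectrograms
spec_lengths = torch.tensor([120, 90])
z, mean, log_scale, y_mask = pe(spec, spec_lengths)
# z, mean, log_scale: [2, 192, 120]; y_mask: [2, 1, 120]
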

View File

@ -0,0 +1,276 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform
class DilatedDepthSeparableConv(nn.Module):
def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0):
"""Dilated Depth-wise Separable Convolution module.
::
x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o
|-------------------------------------------------------------------------------------^
Args:
channels (int): Number of input and output tensor channels.
kernel_size (int): Kernel size of the depth-wise separable convolution layers.
num_layers (int): Number of convolutional layers.
dropout_p (float, optional): Dropout rate. Defaults to 0.0.
Returns:
torch.Tensor: `forward()` output masked by the input sequence mask.
"""
super().__init__()
self.num_layers = num_layers
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(num_layers):
dilation = kernel_size ** i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm2(channels))
self.norms_2.append(LayerNorm2(channels))
self.dropout = nn.Dropout(dropout_p)
def forward(self, x, x_mask, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
"""
if g is not None:
x = x + g
for i in range(self.num_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.dropout(y)
x = x + y
return x * x_mask
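
The dilation grows as `kernel_size ** i` per layer, so the receptive field expands exponentially with depth; a quick sketch (illustrative only, with the module path assumed from how `vits.py` imports this file):

import torch

from TTS.tts.layers.vits.stochastic_duration_predictor import DilatedDepthSeparableConv

conv = DilatedDepthSeparableConv(channels=192, kernel_size=3, num_layers=3, dropout_p=0.1)
x = torch.randn(2, 192, 50)
mask = torch.ones(2, 1, 50)
y = conv(x, mask)  # [2, 192, 50], residual output masked by `mask`
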
class ElementwiseAffine(nn.Module):
"""Element-wise affine transform like no-population stats BatchNorm alternative.
Args:
channels (int): Number of input tensor channels.
"""
def __init__(self, channels):
super().__init__()
self.translation = nn.Parameter(torch.zeros(channels, 1))
self.log_scale = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument
if not reverse:
y = (x * torch.exp(self.log_scale) + self.translation) * x_mask
logdet = torch.sum(self.log_scale * x_mask, [1, 2])
return y, logdet
x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask
return x
class ConvFlow(nn.Module):
"""Dilated depth separable convolutional based spline flow.
Args:
in_channels (int): Number of input tensor channels.
hidden_channels (int): Number of in network channels.
kernel_size (int): Convolutional kernel size.
num_layers (int): Number of convolutional layers.
num_bins (int, optional): Number of spline bins. Defaults to 10.
tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0.
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
num_layers: int,
num_bins=10,
tail_bound=5.0,
):
super().__init__()
self.num_bins = num_bins
self.tail_bound = tail_bound
self.hidden_channels = hidden_channels
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0)
self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
return x
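
Like the coupling blocks, the spline flow is a bijection over the second half of its channels; a round-trip sketch (illustrative only, module path assumed as above):

import torch

from TTS.tts.layers.vits.stochastic_duration_predictor import ConvFlow

flow = ConvFlow(in_channels=2, hidden_channels=64, kernel_size=3, num_layers=3)
x = torch.randn(2, 2, 40)
mask = torch.ones(2, 1, 40)
y, logdet = flow(x, mask)            # forward transform and its log-determinant
x_rec = flow(y, mask, reverse=True)  # inverse transform
print(torch.allclose(x, x_rec, atol=1e-4))  # expected: True
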
class StochasticDurationPredictor(nn.Module):
"""Stochastic duration predictor with Spline Flows.
It applies Variational Dequantization and Variational Data Augmentation.
Paper:
SDP: https://arxiv.org/pdf/2106.06103.pdf
Spline Flow: https://arxiv.org/abs/1906.04032
::
## Inference
x -> TextCondEncoder() -> Flow() -> dr_hat
noise ----------------------^
## Training
|---------------------|
x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise
d -> DurCondEncoder() -> ^ |
|------------------------------------------------------------------------------|
Args:
in_channels (int): Number of input tensor channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size of convolutional layers.
dropout_p (float): Dropout rate.
num_flows (int, optional): Number of flow blocks. Defaults to 4.
cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0.
"""
def __init__(
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
):
super().__init__()
# condition encoder text
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
# posterior encoder
self.flows = nn.ModuleList()
self.flows.append(ElementwiseAffine(2))
self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
# condition encoder duration
self.post_pre = nn.Conv1d(1, hidden_channels, 1)
self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
# flow layers
self.post_flows = nn.ModuleList()
self.post_flows.append(ElementwiseAffine(2))
self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
if cond_channels != 0 and cond_channels is not None:
self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- dr: :math:`[B, 1, T]`
- g: :math:`[B, C]`
"""
# condition encoder text
x = self.pre(x)
if g is not None:
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert dr is not None
# condition encoder duration
h = self.post_pre(dr)
h = self.post_convs(h, x_mask)
h = self.post_proj(h) * x_mask
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = noise
# posterior encoder
logdet_tot_q = 0.0
for idx, flow in enumerate(self.post_flows):
z_q, logdet_q = flow(z_q, x_mask, g=(x + h))
logdet_tot_q = logdet_tot_q + logdet_q
if idx > 0:
z_q = torch.flip(z_q, [1])
z_u, z_v = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (dr - u) * x_mask
# posterior encoder - neg log likelihood
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
nll_posterior_encoder = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
)
z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
logdet_tot = torch.sum(-z0, [1, 2])
z = torch.cat([z0, z_v], 1)
# flow layers
for idx, flow in enumerate(flows):
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
if idx > 0:
z = torch.flip(z, [1])
# flow layers - neg log likelihood
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
return nll_flow_layers + nll_posterior_encoder
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = torch.flip(z, [1])
z = flow(z, x_mask, g=x, reverse=reverse)
z0, _ = torch.split(z, [1, 1], 1)
logw = z0
return logw
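
A usage sketch of the two modes pictured in the diagram above (illustrative only, not part of this patch): training returns a per-sample negative log-likelihood of the ground-truth durations, while inference samples log-durations from noise:

import torch

from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor

sdp = StochasticDurationPredictor(in_channels=192, hidden_channels=192, kernel_size=3, dropout_p=0.5)
x = torch.randn(2, 192, 50)                           # text-encoder outputs [B, C, T]
mask = torch.ones(2, 1, 50)
durations = torch.randint(1, 10, (2, 1, 50)).float()  # ground-truth durations per input token

nll = sdp(x, mask, dr=durations)                     # [B], training objective
log_w = sdp(x, mask, reverse=True, noise_scale=0.8)  # [B, 1, T], predicted log-durations
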

View File

@ -0,0 +1,203 @@
# adapted from https://github.com/bayesiains/nflows
import numpy as np
import torch
from torch.nn import functional as F
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs,
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet
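
The two spline functions implement a bijection, so applying the transform with `inverse=True` undoes the forward mapping and the log-determinants cancel; a quick consistency sketch (illustrative only, not part of this patch):

import torch

from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform

num_bins = 10
x = torch.randn(4, 10)
w = torch.randn(4, 10, num_bins)
h = torch.randn(4, 10, num_bins)
d = torch.randn(4, 10, num_bins - 1)

y, logdet = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear", tail_bound=5.0)
x_rec, logdet_inv = piecewise_rational_quadratic_transform(y, w, h, d, inverse=True, tails="linear", tail_bound=5.0)
print(torch.allclose(x, x_rec, atol=1e-4), torch.allclose(logdet, -logdet_inv, atol=1e-4))  # expected: True True
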

View File

@ -4,20 +4,23 @@ from TTS.utils.generic_utils import find_module

 def setup_model(config):
     print(" > Using model: {}".format(config.model))
     MyModel = find_module("TTS.tts.models", config.model.lower())
     # define set of characters used by the model
     if config.characters is not None:
         # set characters from config
-        symbols, phonemes = make_symbols(**config.characters.to_dict())  # pylint: disable=redefined-outer-name
+        if hasattr(MyModel, "make_symbols"):
+            symbols = MyModel.make_symbols(config)
+        else:
+            symbols, phonemes = make_symbols(**config.characters)
     else:
         from TTS.tts.utils.text.symbols import phonemes, symbols  # pylint: disable=import-outside-toplevel

+        if config.use_phonemes:
+            symbols = phonemes
         # use default characters and assign them to config
         config.characters = parse_symbols()
-    num_chars = len(phonemes) if config.use_phonemes else len(symbols)
     # consider special `blank` character if `add_blank` is set True
-    num_chars = num_chars + getattr(config, "add_blank", False)
+    num_chars = len(symbols) + getattr(config, "add_blank", False)
     config.num_chars = num_chars
     # compatibility fix
     if "model_params" in config:

View File

@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec

 @dataclass
@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

View File

@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.tts.utils.text import make_symbols
 from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler
@ -75,9 +76,6 @@ class BaseTacotron(BaseTTS):
         self.decoder_backward = None
         self.coarse_decoder = None

-        # init multi-speaker layers
-        self.init_multispeaker(config)
-
     @staticmethod
     def _format_aux_input(aux_input: Dict) -> Dict:
         return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
@ -113,7 +111,7 @@ class BaseTacotron(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if "r" in state:
             self.decoder.set_r(state["r"])
@ -236,6 +234,7 @@ class BaseTacotron(BaseTTS):
     def compute_gst(self, inputs, style_input, speaker_embedding=None):
         """Compute global style token"""
         if isinstance(style_input, dict):
+            # multiply each style token with a weight
             query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
             if speaker_embedding is not None:
                 query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
@ -247,8 +246,10 @@ class BaseTacotron(BaseTTS):
             gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
             gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
+            # ignore style token and return zero tensor
             gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
         else:
+            # compute style tokens
             gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
         inputs = self._concat_speaker_embedding(inputs, gst_outputs)
         return inputs

View File

@ -1,6 +1,6 @@
-import os
 from typing import Dict, List, Tuple

+import numpy as np
 import torch
 from coqpit import Coqpit
 from torch import nn
@ -48,10 +48,17 @@ class BaseTTS(BaseModel):
         return get_speaker_manager(config, restore_path, data, out_path)

     def init_multispeaker(self, config: Coqpit, data: List = None):
-        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
-        or with external `d_vectors` computed from a speaker encoder model.
-        If you need a different behaviour, override this function for your model.
+        """Initialize a speaker embedding layer if needed and define the expected embedding channel size for setting
+        the `in_channels` size of the connected layers.
+
+        This implementation yields 3 possible outcomes:
+
+        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
+        2. If `config.use_d_vector_file` is True, set the expected embedding channel size to `config.d_vector_dim` or 512.
+        3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
+           `config.d_vector_dim` or 512.
+
+        You can override this function for new models.

         Args:
             config (Coqpit): Model configuration.
@ -59,12 +66,24 @@ class BaseTTS(BaseModel):
         """
         # init speaker manager
         self.speaker_manager = get_speaker_manager(config, data=data)
-        self.num_speakers = self.speaker_manager.num_speakers
-        # init speaker embedding layer
-        if config.use_speaker_embedding and not config.use_d_vector_file:
+
+        # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
+        if data is not None or self.speaker_manager.speaker_ids:
+            self.num_speakers = self.speaker_manager.num_speakers
+        else:
+            self.num_speakers = (
+                config.num_speakers
+                if "num_speakers" in config and config.num_speakers != 0
+                else self.speaker_manager.num_speakers
+            )
+
+        # set ultimate speaker embedding size
+        if config.use_speaker_embedding or config.use_d_vector_file:
             self.embedded_speaker_dim = (
                 config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
             )
+        # init speaker embedding layer
+        if config.use_speaker_embedding and not config.use_d_vector_file:
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
@ -87,7 +106,7 @@ class BaseTTS(BaseModel):
         text_input = batch[0]
         text_lengths = batch[1]
         speaker_names = batch[2]
-        linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
+        linear_input = batch[3]
         mel_input = batch[4]
         mel_lengths = batch[5]
         stop_targets = batch[6]
@ -95,6 +114,7 @@ class BaseTTS(BaseModel):
         d_vectors = batch[8]
         speaker_ids = batch[9]
         attn_mask = batch[10]
+        waveform = batch[11]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
@ -140,6 +160,7 @@ class BaseTTS(BaseModel):
             "max_text_length": float(max_text_length),
             "max_spec_length": float(max_spec_length),
             "item_idx": item_idx,
+            "waveform": waveform,
         }

     def get_data_loader(
@ -160,15 +181,22 @@ class BaseTTS(BaseModel):
                 speaker_id_mapping = None
                 d_vector_mapping = None

+            # setup custom symbols if needed
+            custom_symbols = None
+            if hasattr(self, "make_symbols"):
+                custom_symbols = self.make_symbols(self.config)
+
            # init dataloader
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                text_cleaner=config.text_cleaner,
-               compute_linear_spec=config.model.lower() == "tacotron",
+               compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                meta_data=data_items,
                ap=ap,
                characters=config.characters,
+               custom_symbols=custom_symbols,
                add_blank=config["add_blank"],
+               return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_seq_len=config.min_seq_len,
                max_seq_len=config.max_seq_len,
@ -185,8 +213,21 @@ class BaseTTS(BaseModel):
            )

            if config.use_phonemes and config.compute_input_seq_cache:
-               # precompute phonemes to have a better estimate of sequence lengths.
-               dataset.compute_input_seq(config.num_loader_workers)
+               if hasattr(self, "eval_data_items") and is_eval:
+                   dataset.items = self.eval_data_items
+               elif hasattr(self, "train_data_items") and not is_eval:
+                   dataset.items = self.train_data_items
+               else:
+                   # precompute phonemes to have a better estimate of sequence lengths.
+                   dataset.compute_input_seq(config.num_loader_workers)
+
+                   # TODO: find a more efficient solution
+                   # cheap hack - store items in the model state to avoid recomputing when reinit the dataset
+                   if is_eval:
+                       self.eval_data_items = dataset.items
+                   else:
+                       self.train_data_items = dataset.items
+
            dataset.sort_items()

            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
@ -216,7 +257,7 @@ class BaseTTS(BaseModel):
         test_sentences = self.config.test_sentences
         aux_inputs = self.get_aux_input()
         for idx, sen in enumerate(test_sentences):
-            wav, alignment, model_outputs, _ = synthesis(
+            outputs_dict = synthesis(
                 self,
                 sen,
                 self.config,
@ -228,9 +269,12 @@ class BaseTTS(BaseModel):
                 enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                 use_griffin_lim=True,
                 do_trim_silence=False,
-            ).values()
-
-            test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
-            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
+            )
+            test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
+            )
+            test_figures["{}-alignment".format(idx)] = plot_alignment(
+                outputs_dict["outputs"]["alignments"], output_fig=False
+            )
         return test_figures, test_audios

49
TTS/tts/models/glow_tts.py Executable file → Normal file
View File

@ -12,14 +12,19 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import get_speaker_manager
+from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec

 class GlowTTS(BaseTTS):
-    """Glow TTS models from https://arxiv.org/abs/2005.11129
+    """GlowTTS model.

-    Paper abstract:
+    Paper::
+        https://arxiv.org/abs/2005.11129
+
+    Paper abstract::
         Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate
         mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained
         without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS,
@ -144,7 +149,6 @@ class GlowTTS(BaseTTS):
             g = F.normalize(g).unsqueeze(-1)
         else:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
-
         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
         # drop residual frames wrt num_squeeze and set y_lengths.
@ -361,12 +365,49 @@ class GlowTTS(BaseTTS):
         train_audio = ap.inv_melspectrogram(pred_spec.T)
         return figures, {"audio": train_audio}

+    @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module):
         return self.train_step(batch, criterion)

     def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
         return self.train_log(ap, batch, outputs)

+    @torch.no_grad()
+    def test_run(self, ap):
+        """Generic test run for `tts` models used by `Trainer`.
+
+        You can override this for a different behaviour.
+
+        Returns:
+            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+        """
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self.get_aux_input()
+        for idx, sen in enumerate(test_sentences):
+            outputs = synthesis(
+                self,
+                sen,
+                self.config,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
+                speaker_id=aux_inputs["speaker_id"],
+                d_vector=aux_inputs["d_vector"],
+                style_wav=aux_inputs["style_wav"],
+                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                use_griffin_lim=True,
+                do_trim_silence=False,
+            )
+            test_audios["{}-audio".format(idx)] = outputs["wav"]
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                outputs["outputs"]["model_outputs"], ap, output_fig=False
+            )
+            test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
+        return test_figures, test_audios
+
     def preprocess(self, y, y_lengths, y_max_length, attn=None):
         if y_max_length is not None:
             y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
@ -382,7 +423,7 @@ class GlowTTS(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

View File

@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec

 @dataclass
@ -105,7 +106,7 @@
             if isinstance(config.model_args.length_scale, int)
             else config.model_args.length_scale
         )
-        self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
+        self.emb = nn.Embedding(self.num_chars, config.model_args.hidden_channels)
         self.encoder = Encoder(
             config.model_args.hidden_channels,
             config.model_args.hidden_channels,
@ -227,6 +228,7 @@
         outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
         return outputs

+    @torch.no_grad()
     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
         """
         Shapes:
@ -306,7 +308,7 @@
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

View File

@ -23,19 +23,19 @@ class Tacotron(BaseTacotron):
     def __init__(self, config: Coqpit):
         super().__init__(config)

-        self.num_chars, self.config = self.get_characters(config)
+        chars, self.config = self.get_characters(config)
+        config.num_chars = self.num_chars = len(chars)

         # pass all config fields to `self`
         # for fewer code change
         for key in config:
             setattr(self, key, config[key])

-        # speaker embedding layer
-        if self.num_speakers > 1:
+        # set speaker embedding channel size for determining `in_channels` for the connected layers.
+        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
+        # on the number of speakers inferred from the dataset.
+        if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
-        # speaker and gst embeddings is concat in decoder input
-        if self.num_speakers > 1:
             self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

         if self.use_gst:
@ -75,13 +75,11 @@
         if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=self.decoder_output_dim,
-                d_vector_dim=self.d_vector_dim
-                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
-                else None,
                 num_heads=self.gst.gst_num_heads,
                 num_style_tokens=self.gst.gst_num_style_tokens,
                 gst_embedding_dim=self.gst.gst_embedding_dim,
             )
+
         # backward pass decoder
         if self.bidirectional_decoder:
             self._init_backward_decoder()
@ -106,7 +104,9 @@
             self.max_decoder_steps,
         )

-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
+    def forward(  # pylint: disable=dangerous-default-value
+        self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
+    ):
         """
         Shapes:
             text: [B, T_in]
@ -115,6 +115,7 @@
             mel_lengths: [B]
             aux_input: 'speaker_ids': [B, 1] and 'd_vectors': [B, C]
         """
+        aux_input = self._format_aux_input(aux_input)
         outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
         inputs = self.embedding(text)
         input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@ -125,12 +126,10 @@
         # global style token
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(
-                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
-            )
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
         # speaker embedding
-        if self.num_speakers > 1:
-            if not self.use_d_vectors:
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
             else:
@ -182,7 +181,7 @@
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])
                 # reshape embedded_speakers

View File

@ -23,7 +23,7 @@ class Tacotron2(BaseTacotron):
         super().__init__(config)

         chars, self.config = self.get_characters(config)
-        self.num_chars = len(chars)
+        config.num_chars = len(chars)
         self.decoder_output_dim = config.out_channels

         # pass all config fields to `self`
@ -31,12 +31,11 @@
         for key in config:
             setattr(self, key, config[key])

-        # speaker embedding layer
-        if self.num_speakers > 1:
+        # set speaker embedding channel size for determining `in_channels` for the connected layers.
+        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
+        # on the number of speakers inferred from the dataset.
+        if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
-        # speaker and gst embeddings is concat in decoder input
-        if self.num_speakers > 1:
             self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

         if self.use_gst:
@ -47,6 +46,7 @@
         # base model layers
         self.encoder = Encoder(self.encoder_in_features)
+
         self.decoder = Decoder(
             self.decoder_in_features,
             self.decoder_output_dim,
@ -73,9 +73,6 @@
         if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=self.decoder_output_dim,
-                d_vector_dim=self.d_vector_dim
-                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
-                else None,
                 num_heads=self.gst.gst_num_heads,
                 num_style_tokens=self.gst.gst_num_style_tokens,
                 gst_embedding_dim=self.gst.gst_embedding_dim,
@ -110,7 +107,9 @@
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments

-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
+    def forward(  # pylint: disable=dangerous-default-value
+        self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
+    ):
         """
         Shapes:
             text: [B, T_in]
@ -130,11 +129,10 @@
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(
-                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
-            )
-        if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
             else:
@ -186,8 +184,9 @@
         if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
+
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            if not self.use_d_vector_file:
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
                 # reshape embedded_speakers
                 if embedded_speakers.ndim == 1:

767
TTS/tts/models/vits.py Normal file
View File

@ -0,0 +1,767 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
"""Segment each sample in a batch based on the provided segment indices"""
segments = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
index_start = segment_indices[i]
index_end = index_start + segment_size
segments[i] = x[i, :, index_start:index_end]
return segments
def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
"""Create random segments based on the input lengths."""
B, _, T = x.size()
if x_lengths is None:
x_lengths = T
max_idxs = x_lengths - segment_size + 1
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
ret = segment(x, ids_str, segment_size)
return ret, ids_str
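
An illustrative sketch of how these two helpers relate (not part of this patch; the `hop_length` of 256 is an assumed value for the example): `rand_segment` picks a random window per sample, and `segment` can re-slice any aligned tensor, such as the waveform, at the same indices:

import torch

from TTS.tts.models.vits import rand_segment, segment

spec = torch.randn(2, 80, 100)           # [B, C, T_spec]
spec_lengths = torch.tensor([100, 60])
spec_seg, idxs = rand_segment(spec, spec_lengths, segment_size=32)  # [2, 80, 32]
hop_length = 256                         # assumed hop length for the sketch
wav = torch.randn(2, 1, 100 * hop_length)
wav_seg = segment(wav, idxs * hop_length, segment_size=32 * hop_length)  # matching waveform windows
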
@dataclass
class VitsArgs(Coqpit):
"""VITS model arguments.
Args:
num_chars (int):
Number of characters in the vocabulary. Defaults to 100.
out_channels (int):
Number of output channels. Defaults to 513.
spec_segment_size (int):
Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`.
hidden_channels (int):
Number of hidden channels of the model. Defaults to 192.
hidden_channels_ffn_text_encoder (int):
Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 768.
num_heads_text_encoder (int):
Number of attention heads of the text encoder transformer. Defaults to 2.
num_layers_text_encoder (int):
Number of transformer layers in the text encoder. Defaults to 6.
kernel_size_text_encoder (int):
Kernel size of the text encoder transformer FFN layers. Defaults to 3.
dropout_p_text_encoder (float):
Dropout rate of the text encoder. Defaults to 0.1.
dropout_p_duration_predictor (float):
Dropout rate of the duration predictor. Defaults to 0.1.
kernel_size_posterior_encoder (int):
Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.
dilation_rate_posterior_encoder (int):
Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.
num_layers_posterior_encoder (int):
Number of posterior encoder's WaveNet layers. Defaults to 16.
kernel_size_flow (int):
Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
dilation_rate_flow (int):
Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
num_layers_flow (int):
Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.
resblock_type_decoder (str):
Type of the residual block in the decoder network. Defaults to "1".
resblock_kernel_sizes_decoder (List[int]):
Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
resblock_dilation_sizes_decoder (List[List[int]]):
Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
upsample_rates_decoder (List[int]):
Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these
values must be equal to the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
upsample_initial_channel_decoder (int):
Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
upsample_kernel_sizes_decoder (List[int]):
Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
use_sdp (bool):
Use Stochastic Duration Predictor. Defaults to True.
noise_scale (float):
Noise scale used for the sample noise tensor in training. Defaults to 1.0.
inference_noise_scale (float):
Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
length_scale (int):
Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.
noise_scale_dp (float):
Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
inference_noise_scale_dp (float):
Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8.
max_inference_len (int):
Maximum inference length to limit the memory use. Defaults to None.
init_discriminator (bool):
Initialize the discriminator network if set True. Set False for inference. Defaults to True.
use_spectral_norm_disriminator (bool):
Use spectral normalization over weight norm in the discriminator. Defaults to False.
use_speaker_embedding (bool):
Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
num_speakers (int):
Number of speakers for the speaker embedding layer. Defaults to 0.
speakers_file (str):
Path to the speaker mapping file for the Speaker Manager. Defaults to None.
speaker_embedding_channels (int):
Number of speaker embedding channels. Defaults to 256.
use_d_vector_file (bool):
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
d_vector_dim (int):
Number of d-vector channels. Defaults to 0.
detach_dp_input (bool):
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
"""
num_chars: int = 100
out_channels: int = 513
spec_segment_size: int = 32
hidden_channels: int = 192
hidden_channels_ffn_text_encoder: int = 768
num_heads_text_encoder: int = 2
num_layers_text_encoder: int = 6
kernel_size_text_encoder: int = 3
dropout_p_text_encoder: float = 0.1
dropout_p_duration_predictor: float = 0.1
kernel_size_posterior_encoder: int = 5
dilation_rate_posterior_encoder: int = 1
num_layers_posterior_encoder: int = 16
kernel_size_flow: int = 5
dilation_rate_flow: int = 1
num_layers_flow: int = 4
resblock_type_decoder: str = "1"
resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
upsample_initial_channel_decoder: int = 512
upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
use_sdp: bool = True
noise_scale: float = 1.0
inference_noise_scale: float = 0.667
length_scale: int = 1
noise_scale_dp: float = 1.0
inference_noise_scale_dp: float = 0.8
max_inference_len: int = None
init_discriminator: bool = True
use_spectral_norm_disriminator: bool = False
use_speaker_embedding: bool = False
num_speakers: int = 0
speakers_file: str = None
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
d_vector_dim: int = 0
detach_dp_input: bool = True
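
These fields feed the `Vits` constructor below. As in the class docstring further down, a model can be built from a default `VitsConfig`, which wraps a `VitsArgs` instance; a small sketch (illustrative only, not part of this patch) for checking the resulting model size:

from TTS.tts.configs import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()  # default single-speaker architecture built from the VitsArgs defaults above
model = Vits(config)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params / 1e6:.1f}M trainable parameters")
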
class Vits(BaseTTS):
"""VITS TTS model
Paper::
https://arxiv.org/pdf/2106.06103.pdf
Paper Abstract::
Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel
sampling have been proposed, but their sample quality does not match that of two-stage TTS systems.
In this work, we present a parallel end-to-end TTS method that generates more natural sounding audio than
current two-stage models. Our method adopts variational inference augmented with normalizing flows and
an adversarial training process, which improves the expressive power of generative modeling. We also propose a
stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the
uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the
natural one-to-many relationship in which a text input can be spoken in multiple ways
with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS)
on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly
available TTS systems and achieves a MOS comparable to ground truth.
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
Examples:
>>> from TTS.tts.configs import VitsConfig
>>> from TTS.tts.models.vits import Vits
>>> config = VitsConfig()
>>> model = Vits(config)
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
super().__init__()
self.END2END = True
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
if "num_chars" not in config:
_, self.config, num_chars = self.get_characters(config)
config.model_args.num_chars = num_chars
else:
self.config = config
config.model_args.num_chars = config.num_chars
args = self.config.model_args
elif isinstance(config, VitsArgs):
# loading from VitsArgs
self.config = config
args = config
else:
raise ValueError("config must be either a VitsConfig or VitsArgs")
self.args = args
self.init_multispeaker(config)
self.length_scale = args.length_scale
self.noise_scale = args.noise_scale
self.inference_noise_scale = args.inference_noise_scale
self.inference_noise_scale_dp = args.inference_noise_scale_dp
self.noise_scale_dp = args.noise_scale_dp
self.max_inference_len = args.max_inference_len
self.spec_segment_size = args.spec_segment_size
self.text_encoder = TextEncoder(
args.num_chars,
args.hidden_channels,
args.hidden_channels,
args.hidden_channels_ffn_text_encoder,
args.num_heads_text_encoder,
args.num_layers_text_encoder,
args.kernel_size_text_encoder,
args.dropout_p_text_encoder,
)
self.posterior_encoder = PosteriorEncoder(
args.out_channels,
args.hidden_channels,
args.hidden_channels,
kernel_size=args.kernel_size_posterior_encoder,
dilation_rate=args.dilation_rate_posterior_encoder,
num_layers=args.num_layers_posterior_encoder,
cond_channels=self.embedded_speaker_dim,
)
self.flow = ResidualCouplingBlocks(
args.hidden_channels,
args.hidden_channels,
kernel_size=args.kernel_size_flow,
dilation_rate=args.dilation_rate_flow,
num_layers=args.num_layers_flow,
cond_channels=self.embedded_speaker_dim,
)
if args.use_sdp:
self.duration_predictor = StochasticDurationPredictor(
args.hidden_channels,
192,
3,
args.dropout_p_duration_predictor,
4,
cond_channels=self.embedded_speaker_dim,
)
else:
self.duration_predictor = DurationPredictor(
args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
)
self.waveform_decoder = HifiganGenerator(
args.hidden_channels,
1,
args.resblock_type_decoder,
args.resblock_dilation_sizes_decoder,
args.resblock_kernel_sizes_decoder,
args.upsample_kernel_sizes_decoder,
args.upsample_initial_channel_decoder,
args.upsample_rates_decoder,
inference_padding=0,
cond_channels=self.embedded_speaker_dim,
conv_pre_weight_norm=False,
conv_post_weight_norm=False,
conv_post_bias=False,
)
if args.init_discriminator:
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
If you need a different behaviour, override this function for your model.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
if hasattr(config, "model_args"):
config = config.model_args
self.embedded_speaker_dim = 0
# init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data)
if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
self.speaker_manager.num_speakers = config.num_speakers
self.num_speakers = self.speaker_manager.num_speakers
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = config.speaker_embedding_channels
self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
# init d-vector usage
if config.use_d_vector_file:
self.embedded_speaker_dim = config.d_vector_dim
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g = None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
sid = sid.unsqueeze_(0)
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
g = aux_input["d_vectors"]
return sid, g
def forward(
self,
x: torch.tensor,
x_lengths: torch.tensor,
y: torch.tensor,
y_lengths: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None},
) -> Dict:
"""Forward pass of the model.
Args:
x (torch.tensor): Batch of input character sequence IDs.
x_lengths (torch.tensor): Batch of input character sequence lengths.
y (torch.tensor): Batch of input spectrograms.
y_lengths (torch.tensor): Batch of input spectrogram lengths.
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
Returns:
Dict: model outputs keyed by the output name.
Shapes:
- x: :math:`[B, T_seq]`
- x_lengths: :math:`[B]`
- y: :math:`[B, C, T_spec]`
- y_lengths: :math:`[B]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
"""
outputs = {}
sid, g = self._set_cond_input(aux_input)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
# speaker embedding
if self.num_speakers > 1 and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
# flow layers
z_p = self.flow(z, y_mask, g=g)
# find the alignment path
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
# duration predictor
attn_durations = attn.sum(3)
if self.args.use_sdp:
nll_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
outputs["nll_duration"] = nll_duration
else:
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
log_durations = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration
# expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
o = self.waveform_decoder(z_slice, g=g)
outputs.update(
{
"model_outputs": o,
"alignments": attn.squeeze(1),
"slice_ids": slice_ids,
"z": z,
"z_p": z_p,
"m_p": m_p,
"logs_p": logs_p,
"m_q": m_q,
"logs_q": logs_q,
}
)
return outputs
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
"""
Shapes:
- x: :math:`[B, T_seq]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
"""
sid, g = self._set_cond_input(aux_input)
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
if self.num_speakers > 0 and sid:
g = self.emb_g(sid).unsqueeze(-1)
if self.args.use_sdp:
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
else:
logw = self.duration_predictor(x, x_mask, g=g)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
return outputs
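# Inference usage sketch (hedged, not part of the class): convert text to character IDs
# with the text utilities and call `inference()` directly. `model` is assumed to be a
# trained Vits instance and the cleaner name is illustrative.
#
# import torch
# from TTS.tts.utils.text import text_to_sequence
#
# token_ids = torch.LongTensor(text_to_sequence("Hello world.", ["basic_cleaners"])).unsqueeze(0)
# with torch.no_grad():
#     outputs = model.inference(token_ids)
# wav = outputs["model_outputs"].squeeze().cpu().numpy()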
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
"""TODO: create an end-point for voice conversion"""
assert self.num_speakers > 0, "num_speakers has to be larger than 0."
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
Returns:
Tuple[Dict, Dict]: Model outputs and computed losses.
"""
# pylint: disable=attribute-defined-outside-init
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
if optimizer_idx == 0:
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
waveform = batch["waveform"]
# generator pass
outputs = self.forward(
text_input,
text_lengths,
linear_input.transpose(1, 2),
mel_lengths,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
)
# cache tensors for the discriminator
self.y_disc_cache = None
self.wav_seg_disc_cache = None
self.y_disc_cache = outputs["model_outputs"]
wav_seg = segment(
waveform.transpose(1, 2),
outputs["slice_ids"] * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
)
self.wav_seg_disc_cache = wav_seg
outputs["waveform_seg"] = wav_seg
# compute discriminator scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
_, outputs["feats_disc_real"] = self.disc(wav_seg)
# compute losses
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
waveform_hat=outputs["model_outputs"].float(),
waveform=wav_seg.float(),
z_p=outputs["z_p"].float(),
logs_q=outputs["logs_q"].float(),
m_p=outputs["m_p"].float(),
logs_p=outputs["logs_p"].float(),
z_len=mel_lengths,
scores_disc_fake=outputs["scores_disc_fake"],
feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"],
)
# handle the duration loss
if self.args.use_sdp:
loss_dict["nll_duration"] = outputs["nll_duration"]
loss_dict["loss"] += outputs["nll_duration"]
else:
loss_dict["loss_duration"] = outputs["loss_duration"]
loss_dict["loss"] += outputs["nll_duration"]
elif optimizer_idx == 1:
# discriminator pass
outputs = {}
# compute scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
# compute loss
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
outputs["scores_disc_real"],
outputs["scores_disc_fake"],
)
return outputs, loss_dict
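# Trainer-side sketch (assumed loop, not the actual Trainer code): optimizer index 0
# updates the generator, index 1 the discriminator, each with its own criterion.
#
# criteria = model.get_criterion()
# optimizers = model.get_optimizer()
# for idx in (0, 1):
#     optimizers[idx].zero_grad()
#     _, loss_dict = model.train_step(batch, criteria, idx)
#     loss_dict["loss"].backward()
#     optimizers[idx].step()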
def train_log(
self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
): # pylint: disable=no-self-use
"""Create visualizations and waveform examples.
For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
be projected onto Tensorboard.
Args:
ap (AudioProcessor): audio processor used at training.
batch (Dict): Model inputs used at the previous training step.
outputs (Dict): Model outputs generated at the previous training step.
Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform.
"""
y_hat = outputs[0]["model_outputs"]
y = outputs[0]["waveform_seg"]
figures = plot_results(y_hat, y, ap, name_prefix)
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
audios = {f"{name_prefix}/audio": sample_voice}
alignments = outputs[0]["alignments"]
align_img = alignments[0].data.cpu().numpy().T
figures.update(
{
"alignment": plot_alignment(align_img, output_fig=False),
}
)
return figures, audios
@torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs, "eval")
@torch.no_grad()
def test_run(self, ap) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
wav, alignment, _, _ = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
return test_figures, test_audios
def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returns two optimizers in a list; the first is for the generator and the second is for the discriminator.
Returns:
List: optimizers.
"""
self.disc.requires_grad_(False)
gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
self.disc.requires_grad_(True)
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
)
optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
return [optimizer1, optimizer2]
def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
Returns:
List: learning rates for each optimizer.
"""
return [self.config.lr_gen, self.config.lr_disc]
def get_scheduler(self, optimizer) -> List:
"""Set the schedulers for each optimizer.
Args:
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
Returns:
List: Schedulers, one for each optimizer.
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2]
def get_criterion(self):
"""Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
`train_step()`"""
from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel
VitsDiscriminatorLoss,
VitsGeneratorLoss,
)
return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
@staticmethod
def make_symbols(config):
"""Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
whole training and inference steps."""
_pad = config.characters["pad"]
_punctuations = config.characters["punctuations"]
_letters = config.characters["characters"]
_letters_ipa = config.characters["phonemes"]
symbols = [_pad] + list(_punctuations) + list(_letters)
if config.use_phonemes:
symbols += list(_letters_ipa)
return symbols
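# Example (hypothetical config values, used only to show the resulting ordering):
#
# from types import SimpleNamespace
#
# cfg = SimpleNamespace(
#     characters={"pad": "_", "punctuations": "!,. ", "characters": "abc", "phonemes": "ɑɐ"},
#     use_phonemes=True,
# )
# Vits.make_symbols(cfg)
# # -> ['_', '!', ',', '.', ' ', 'a', 'b', 'c', 'ɑ', 'ɐ']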
@staticmethod
def get_characters(config: Coqpit):
if config.characters is not None:
symbols = Vits.make_symbols(config)
else:
from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel
parse_symbols,
phonemes,
symbols,
)
config.characters = parse_symbols()
if config.use_phonemes:
symbols = phonemes
num_chars = len(symbols) + getattr(config, "add_blank", False)
return symbols, config, num_chars
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
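A minimal end-to-end sketch of the workflow above, assuming a trained checkpoint and its `config.json` on disk; the paths are placeholders and `VitsConfig.load_json()` is assumed from Coqpit:

from TTS.tts.configs import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()
config.load_json("/path/to/config.json")          # placeholder path
model = Vits(config)
model.load_checkpoint(config, "/path/to/best_model.pth.tar", eval=True)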

View File

@ -2,6 +2,7 @@ import datetime
import importlib import importlib
import pickle import pickle
import fsspec
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r, "r": r,
} }
state.update(kwargs) state.update(kwargs)
pickle.dump(state, open(output_path, "wb")) with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path): def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, "rb")) with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights tf_vars = model.weights
for tf_var in tf_vars: for tf_var in tf_vars:

View File

@ -1,6 +1,7 @@
import datetime import datetime
import pickle import pickle
import fsspec
import tensorflow as tf import tensorflow as tf
@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r, "r": r,
} }
state.update(kwargs) state.update(kwargs)
pickle.dump(state, open(output_path, "wb")) with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path): def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, "rb")) with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights tf_vars = model.weights
for tf_var in tf_vars: for tf_var in tf_vars:

View File

@ -1,3 +1,4 @@
import fsspec
import tensorflow as tf import tensorflow as tf
@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None: if output_path is not None:
# same model binary if outputpath is provided # same model binary if outputpath is provided
with open(output_path, "wb") as f: with fsspec.open(output_path, "wb") as f:
f.write(tflite_model) f.write(tflite_model)
return None return None
return tflite_model return tflite_model

23
TTS/tts/utils/speakers.py Executable file → Normal file
View File

@ -3,6 +3,7 @@ import os
import random import random
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
import fsspec
import numpy as np import numpy as np
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
@ -84,12 +85,12 @@ class SpeakerManager:
@staticmethod @staticmethod
def _load_json(json_file_path: str) -> Dict: def _load_json(json_file_path: str) -> Dict:
with open(json_file_path) as f: with fsspec.open(json_file_path, "r") as f:
return json.load(f) return json.load(f)
@staticmethod @staticmethod
def _save_json(json_file_path: str, data: dict) -> None: def _save_json(json_file_path: str, data: dict) -> None:
with open(json_file_path, "w") as f: with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
@property @property
@ -294,9 +295,10 @@ def _set_file_path(path):
Intended to band aid the different paths returned in restored and continued training.""" Intended to band aid the different paths returned in restored and continued training."""
path_restore = os.path.join(os.path.dirname(path), "speakers.json") path_restore = os.path.join(os.path.dirname(path), "speakers.json")
path_continue = os.path.join(path, "speakers.json") path_continue = os.path.join(path, "speakers.json")
if os.path.exists(path_restore): fs = fsspec.get_mapper(path).fs
if fs.exists(path_restore):
return path_restore return path_restore
if os.path.exists(path_continue): if fs.exists(path_continue):
return path_continue return path_continue
raise FileNotFoundError(f" [!] `speakers.json` not found in {path}") raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")
@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
json_file = out_path json_file = out_path
else: else:
json_file = _set_file_path(out_path) json_file = _set_file_path(out_path)
with open(json_file) as f: with fsspec.open(json_file, "r") as f:
return json.load(f) return json.load(f)
@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present.""" """Saves speaker mapping if not yet present."""
if out_path is not None: if out_path is not None:
speakers_json_path = _set_file_path(out_path) speakers_json_path = _set_file_path(out_path)
with open(speakers_json_path, "w") as f: with fsspec.open(speakers_json_path, "w") as f:
json.dump(speaker_mapping, f, indent=4) json.dump(speaker_mapping, f, indent=4)
@ -358,10 +360,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
elif c.use_d_vector_file and c.d_vector_file: elif c.use_d_vector_file and c.d_vector_file:
# new speaker manager with external speaker embeddings. # new speaker manager with external speaker embeddings.
speaker_manager.set_d_vectors_from_file(c.d_vector_file) speaker_manager.set_d_vectors_from_file(c.d_vector_file)
elif c.use_d_vector_file and not c.d_vector_file: # new speaker manager with speaker IDs file. elif c.use_d_vector_file and not c.d_vector_file:
raise "use_d_vector_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder" raise "use_d_vector_file is True, so you need pass a external speaker embedding file."
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
# new speaker manager with speaker IDs file.
speaker_manager.set_speaker_ids_from_file(c.speakers_file)
print( print(
" > Training with {} speakers: {}".format( " > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids) speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
) )
) )
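A hedged sketch of the new `speakers_file` branch above; the config fields come from the code, while the path and `train_samples` are illustrative:

config.use_speaker_embedding = True
config.use_d_vector_file = False
config.speakers_file = "/path/to/speakers.json"                    # placeholder mapping file
speaker_manager = get_speaker_manager(config, data=train_samples)  # `train_samples` assumed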

View File

@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
import tensorflow as tf import tensorflow as tf
def text_to_seq(text, CONFIG): def text_to_seq(text, CONFIG, custom_symbols=None):
text_cleaner = [CONFIG.text_cleaner] text_cleaner = [CONFIG.text_cleaner]
# text or phonemes to sequence vector # text or phonemes to sequence vector
if CONFIG.use_phonemes: if CONFIG.use_phonemes:
@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG):
tp=CONFIG.characters, tp=CONFIG.characters,
add_blank=CONFIG.add_blank, add_blank=CONFIG.add_blank,
use_espeak_phonemes=CONFIG.use_espeak_phonemes, use_espeak_phonemes=CONFIG.use_espeak_phonemes,
custom_symbols=custom_symbols,
), ),
dtype=np.int32, dtype=np.int32,
) )
else: else:
seq = np.asarray( seq = np.asarray(
text_to_sequence( text_to_sequence(
text, text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
text_cleaner,
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
), ),
dtype=np.int32, dtype=np.int32,
) )
@ -229,13 +227,16 @@ def synthesis(
""" """
# GST processing # GST processing
style_mel = None style_mel = None
custom_symbols = None
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
if isinstance(style_wav, dict): if isinstance(style_wav, dict):
style_mel = style_wav style_mel = style_wav
else: else:
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
if hasattr(model, "make_symbols"):
custom_symbols = model.make_symbols(CONFIG)
# preprocess the given text # preprocess the given text
text_inputs = text_to_seq(text, CONFIG) text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
# pass tensors to backend # pass tensors to backend
if backend == "torch": if backend == "torch":
if speaker_id is not None: if speaker_id is not None:
@ -274,15 +275,18 @@ def synthesis(
# convert outputs to numpy # convert outputs to numpy
# plot results # plot results
wav = None wav = None
if use_griffin_lim: if hasattr(model, "END2END") and model.END2END:
wav = inv_spectrogram(model_outputs, ap, CONFIG) wav = model_outputs.squeeze(0)
# trim silence else:
if do_trim_silence: if use_griffin_lim:
wav = trim_silence(wav, ap) wav = inv_spectrogram(model_outputs, ap, CONFIG)
# trim silence
if do_trim_silence:
wav = trim_silence(wav, ap)
return_dict = { return_dict = {
"wav": wav, "wav": wav,
"alignments": alignments, "alignments": alignments,
"model_outputs": model_outputs,
"text_inputs": text_inputs, "text_inputs": text_inputs,
"outputs": outputs,
} }
return return_dict return return_dict
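A hedged usage sketch of the new end-to-end branch: models that set `END2END = True` (such as Vits above) return the decoder output directly as the waveform, bypassing Griffin-Lim. `model`, `config` and `ap` are assumed to be set up already:

return_dict = synthesis(model, "This is a test.", config, False, ap)  # use_cuda=False
wav = return_dict["wav"]              # waveform for end-to-end models, Griffin-Lim output otherwise
alignments = return_dict["alignments"]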

View File

@ -2,10 +2,9 @@
# adapted from https://github.com/keithito/tacotron # adapted from https://github.com/keithito/tacotron
import re import re
import unicodedata from typing import Dict, List
import gruut import gruut
from packaging import version
from TTS.tts.utils.text import cleaners from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_symbols = symbols _symbols = symbols
_phonemes = phonemes _phonemes = phonemes
# Regular expression matching text enclosed in curly braces: # Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)") _CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
@ -81,7 +81,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
# Fix a few phonemes # Fix a few phonemes
ph = ph.translate(GRUUT_TRANS_TABLE) ph = ph.translate(GRUUT_TRANS_TABLE)
print(" > Phonemes: {}".format(ph))
return ph return ph
raise ValueError(f" [!] Language {language} is not supported for phonemization.") raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@ -106,13 +105,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
def phoneme_to_sequence( def phoneme_to_sequence(
text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False text: str,
): cleaner_names: List[str],
language: str,
enable_eos_bos: bool = False,
custom_symbols: List[str] = None,
tp: Dict = None,
add_blank: bool = False,
use_espeak_phonemes: bool = False,
) -> List[int]:
"""Converts a string of phonemes to a sequence of IDs.
If `custom_symbols` is provided, it will override the default symbols.
Args:
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
language (str): text language key for phonemization.
enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
use_espeak_phonemes (bool): use espeak-based lexicons to convert phonemes to a sequence of IDs.
Returns:
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement # pylint: disable=global-statement
global _phonemes_to_id, _phonemes global _phonemes_to_id, _phonemes
if tp:
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp) _, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
sequence = [] sequence = []
clean_text = _clean_text(text, cleaner_names) clean_text = _clean_text(text, cleaner_names)
@ -127,20 +151,22 @@ def phoneme_to_sequence(
sequence = pad_with_eos_bos(sequence, tp=tp) sequence = pad_with_eos_bos(sequence, tp=tp)
if add_blank: if add_blank:
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
return sequence return sequence
def sequence_to_phoneme(sequence, tp=None, add_blank=False): def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None):
# pylint: disable=global-statement # pylint: disable=global-statement
"""Converts a sequence of IDs back to a string""" """Converts a sequence of IDs back to a string"""
global _id_to_phonemes, _phonemes global _id_to_phonemes, _phonemes
if add_blank: if add_blank:
sequence = list(filter(lambda x: x != len(_phonemes), sequence)) sequence = list(filter(lambda x: x != len(_phonemes), sequence))
result = "" result = ""
if tp:
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp) _, _phonemes = make_symbols(**tp)
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
for symbol_id in sequence: for symbol_id in sequence:
if symbol_id in _id_to_phonemes: if symbol_id in _id_to_phonemes:
@ -149,27 +175,32 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
return result.replace("}{", " ") return result.replace("}{", " ")
def text_to_sequence(text, cleaner_names, tp=None, add_blank=False): def text_to_sequence(
text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
) -> List[int]:
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text. """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
If `custom_symbols` is provided, it will override the default symbols.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args: Args:
text: string to convert to a sequence text (str): string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through cleaner_names (List[str]): names of the cleaner functions to run the text through
tp: dictionary of character parameters to use a custom character set. tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
Returns: Returns:
List of integers corresponding to the symbols in the text List[int]: List of integers corresponding to the symbols in the text
""" """
# pylint: disable=global-statement # pylint: disable=global-statement
global _symbol_to_id, _symbols global _symbol_to_id, _symbols
if tp:
if custom_symbols is not None:
_symbols = custom_symbols
elif tp:
_symbols, _ = make_symbols(**tp) _symbols, _ = make_symbols(**tp)
_symbol_to_id = {s: i for i, s in enumerate(_symbols)} _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
sequence = [] sequence = []
# Check for curly braces and treat their contents as ARPAbet: # Check for curly braces and treat their contents as ARPAbet:
while text: while text:
m = _CURLY_RE.match(text) m = _CURLY_RE.match(text)
@ -185,16 +216,18 @@ def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
return sequence return sequence
def sequence_to_text(sequence, tp=None, add_blank=False): def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
"""Converts a sequence of IDs back to a string""" """Converts a sequence of IDs back to a string"""
# pylint: disable=global-statement # pylint: disable=global-statement
global _id_to_symbol, _symbols global _id_to_symbol, _symbols
if add_blank: if add_blank:
sequence = list(filter(lambda x: x != len(_symbols), sequence)) sequence = list(filter(lambda x: x != len(_symbols), sequence))
if tp: if custom_symbols is not None:
_symbols = custom_symbols
elif tp:
_symbols, _ = make_symbols(**tp) _symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)} _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
result = "" result = ""
for symbol_id in sequence: for symbol_id in sequence:
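A hedged example of the new `custom_symbols` argument: a model-specific symbol list (for instance from `Vits.make_symbols()` above) overrides the default symbol table for both encoding and decoding; `model` and `config` are assumed:

custom_symbols = model.make_symbols(config)
seq = text_to_sequence("hello world", ["basic_cleaners"], custom_symbols=custom_symbols)
text = sequence_to_text(seq, custom_symbols=custom_symbols)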

View File

@ -28,10 +28,10 @@ def make_symbols(
sorted(list(set(phonemes))) if unique else sorted(list(phonemes)) sorted(list(set(phonemes))) if unique else sorted(list(phonemes))
) # this is to keep previous models compatible. ) # this is to keep previous models compatible.
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ["@" + s for s in _phonemes_sorted] # _arpabet = ["@" + s for s in _phonemes_sorted]
# Export all symbols: # Export all symbols:
_phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
_symbols += _arpabet # _symbols += _arpabet
return _symbols, _phonemes return _symbols, _phonemes

View File

@ -14,7 +14,10 @@ from TTS.tts.utils.data import StandardScaler
class TorchSTFT(nn.Module): # pylint: disable=abstract-method class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""TODO: Merge this with audio.py""" """Some of the audio processing funtions using Torch for faster batch processing.
TODO: Merge this with audio.py
"""
def __init__( def __init__(
self, self,
@ -28,6 +31,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
mel_fmax=None, mel_fmax=None,
n_mels=80, n_mels=80,
use_mel=False, use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
): ):
super().__init__() super().__init__()
self.n_fft = n_fft self.n_fft = n_fft
@ -39,6 +44,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.mel_fmax = mel_fmax self.mel_fmax = mel_fmax
self.n_mels = n_mels self.n_mels = n_mels
self.use_mel = use_mel self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None self.mel_basis = None
if use_mel: if use_mel:
@ -79,6 +86,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
if self.use_mel: if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S) S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S return S
def _build_mel_basis(self): def _build_mel_basis(self):
@ -87,6 +96,14 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
) )
self.mel_basis = torch.from_numpy(mel_basis).float() self.mel_basis = torch.from_numpy(mel_basis).float()
@staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
@staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain
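# Hedged usage sketch (not part of the diff above): compute a log-mel spectrogram in one
# call with the new flags. All values and the call shape are illustrative assumptions,
# not taken from a specific config.
#
# stft = TorchSTFT(
#     n_fft=1024, hop_length=256, win_length=1024,
#     sample_rate=22050, n_mels=80, use_mel=True,
#     do_amp_to_db=True, spec_gain=1.0,
# )
# log_mel = stft(torch.randn(1, 22050))   # assumed output shape: [B, n_mels, T]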
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
class AudioProcessor(object): class AudioProcessor(object):
@ -97,33 +114,93 @@ class AudioProcessor(object):
of the class with the model config. They are not meaningful for all the arguments. of the class with the model config. They are not meaningful for all the arguments.
Args: Args:
sample_rate (int, optional): target audio sampling rate. Defaults to None. sample_rate (int, optional):
resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. target audio sampling rate. Defaults to None.
num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
log_func (int, optional): log exponent used for converting spectrogram aplitude to DB. resample (bool, optional):
min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None. enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None. num_mels (int, optional):
hop_length (int, optional): number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. number of melspectrogram dimensions. Defaults to None.
win_length (int, optional): STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
ref_level_db (int, optional): reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None. log_func (int, optional):
fft_size (int, optional): FFT window size for STFT. Defaults to 1024. log exponent used for converting spectrogram amplitude to DB.
power (int, optional): Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
preemphasis (float, optional): Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. min_level_db (int, optional):
signal_norm (bool, optional): enable/disable signal normalization. Defaults to None. minimum db threshold for the computed melspectrograms. Defaults to None.
symmetric_norm (bool, optional): enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
max_norm (float, optional): ```k``` defining the normalization range. Defaults to None. frame_shift_ms (int, optional):
mel_fmin (int, optional): minimum filter frequency for computing melspectrograms. Defaults to None. milliseconds of frames between STFT columns. Defaults to None.
mel_fmax (int, optional): maximum filter frequency for computing melspectrograms.. Defaults to None.
spec_gain (int, optional): gain applied when converting amplitude to DB. Defaults to 20. frame_length_ms (int, optional):
stft_pad_mode (str, optional): Padding mode for STFT. Defaults to 'reflect'. milliseconds of STFT window length. Defaults to None.
clip_norm (bool, optional): enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
griffin_lim_iters (int, optional): Number of GriffinLim iterations. Defaults to None. hop_length (int, optional):
do_trim_silence (bool, optional): enable/disable silence trimming when loading the audio signal. Defaults to False. number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
trim_db (int, optional): DB threshold used for silence trimming. Defaults to 60.
do_sound_norm (bool, optional): enable/disable signal normalization. Defaults to False. win_length (int, optional):
stats_path (str, optional): Path to the computed stats file. Defaults to None. STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
verbose (bool, optional): enable/disable logging. Defaults to True.
ref_level_db (int, optional):
reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
fft_size (int, optional):
FFT window size for STFT. Defaults to 1024.
power (int, optional):
Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
preemphasis (float, optional):
Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
signal_norm (bool, optional):
enable/disable signal normalization. Defaults to None.
symmetric_norm (bool, optional):
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
max_norm (float, optional):
```k``` defining the normalization range. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
spec_gain (int, optional):
gain applied when converting amplitude to DB. Defaults to 20.
stft_pad_mode (str, optional):
Padding mode for STFT. Defaults to 'reflect'.
clip_norm (bool, optional):
enable/disable clipping the out-of-range values in the normalized audio signal. Defaults to True.
griffin_lim_iters (int, optional):
Number of GriffinLim iterations. Defaults to None.
do_trim_silence (bool, optional):
enable/disable silence trimming when loading the audio signal. Defaults to False.
trim_db (int, optional):
DB threshold used for silence trimming. Defaults to 60.
do_sound_norm (bool, optional):
enable/disable signal normalization. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
stats_path (str, optional):
Path to the computed stats file. Defaults to None.
verbose (bool, optional):
enable/disable logging. Defaults to True.
""" """
def __init__( def __init__(
@ -153,6 +230,8 @@ class AudioProcessor(object):
do_trim_silence=False, do_trim_silence=False,
trim_db=60, trim_db=60,
do_sound_norm=False, do_sound_norm=False,
do_amp_to_db_linear=True,
do_amp_to_db_mel=True,
stats_path=None, stats_path=None,
verbose=True, verbose=True,
**_, **_,
@ -181,6 +260,8 @@ class AudioProcessor(object):
self.do_trim_silence = do_trim_silence self.do_trim_silence = do_trim_silence
self.trim_db = trim_db self.trim_db = trim_db
self.do_sound_norm = do_sound_norm self.do_sound_norm = do_sound_norm
self.do_amp_to_db_linear = do_amp_to_db_linear
self.do_amp_to_db_mel = do_amp_to_db_mel
self.stats_path = stats_path self.stats_path = stats_path
# setup exp_func for db to amp conversion # setup exp_func for db to amp conversion
if log_func == "np.log": if log_func == "np.log":
@ -381,7 +462,6 @@ class AudioProcessor(object):
Returns: Returns:
np.ndarray: Decibels spectrogram. np.ndarray: Decibels spectrogram.
""" """
return self.spec_gain * _log(np.maximum(1e-5, x), self.base) return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
# pylint: disable=no-self-use # pylint: disable=no-self-use
@ -448,7 +528,10 @@ class AudioProcessor(object):
D = self._stft(self.apply_preemphasis(y)) D = self._stft(self.apply_preemphasis(y))
else: else:
D = self._stft(y) D = self._stft(y)
S = self._amp_to_db(np.abs(D)) if self.do_amp_to_db_linear:
S = self._amp_to_db(np.abs(D))
else:
S = np.abs(D)
return self.normalize(S).astype(np.float32) return self.normalize(S).astype(np.float32)
def melspectrogram(self, y: np.ndarray) -> np.ndarray: def melspectrogram(self, y: np.ndarray) -> np.ndarray:
@ -457,7 +540,10 @@ class AudioProcessor(object):
D = self._stft(self.apply_preemphasis(y)) D = self._stft(self.apply_preemphasis(y))
else: else:
D = self._stft(y) D = self._stft(y)
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) if self.do_amp_to_db_mel:
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
else:
S = self._linear_to_mel(np.abs(D))
return self.normalize(S).astype(np.float32) return self.normalize(S).astype(np.float32)
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
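A hedged sketch of the new `do_amp_to_db_*` flags: with `do_amp_to_db_mel=False` the processor returns linear-amplitude mel spectrograms instead of dB. All constructor values and the wav path are illustrative:

ap = AudioProcessor(
    sample_rate=22050, fft_size=1024, hop_length=256, win_length=1024,
    num_mels=80, mel_fmin=0, mel_fmax=8000,
    do_amp_to_db_linear=True, do_amp_to_db_mel=False,
)
wav = ap.load_wav("/path/to/sample.wav")   # placeholder path
mel = ap.melspectrogram(wav)               # amplitude mel (no dB conversion)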

View File

@ -1,15 +1,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime import datetime
import glob
import importlib import importlib
import os import os
import re import re
import shutil
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
import fsspec
import torch import torch
@ -58,23 +57,22 @@ def get_commit_hash():
return commit return commit
def create_experiment_folder(root_path, model_name): def get_experiment_folder_path(root_path, model_name):
"""Create a folder with the current date and time""" """Get an experiment folder path with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash() commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder)) print(" > Experiment folder: {}".format(output_folder))
return output_folder return output_folder
def remove_experiment_folder(experiment_path): def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder""" """Check folder if there is a checkpoint, otherwise remove the folder"""
fs = fsspec.get_mapper(experiment_path).fs
checkpoint_files = glob.glob(experiment_path + "/*.pth.tar") checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
if not checkpoint_files: if not checkpoint_files:
if os.path.exists(experiment_path): if fs.exists(experiment_path):
shutil.rmtree(experiment_path, ignore_errors=True) fs.rm(experiment_path, recursive=True)
print(" ! Run is removed from {}".format(experiment_path)) print(" ! Run is removed from {}".format(experiment_path))
else: else:
print(" ! Run is kept in {}".format(experiment_path)) print(" ! Run is kept in {}".format(experiment_path))

View File

@ -1,9 +1,11 @@
import datetime import datetime
import glob import json
import os import os
import pickle as pickle_tts import pickle as pickle_tts
from shutil import copyfile import shutil
from typing import Any
import fsspec
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
@ -24,7 +26,7 @@ class AttrDict(dict):
self.__dict__ = self self.__dict__ = self
def copy_model_files(config, out_path, new_fields): def copy_model_files(config: Coqpit, out_path, new_fields):
"""Copy config.json and other model files to training folder and add """Copy config.json and other model files to training folder and add
new fields. new fields.
@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
copy_config_path = os.path.join(out_path, "config.json") copy_config_path = os.path.join(out_path, "config.json")
# add extra information fields # add extra information fields
config.update(new_fields, allow_new=True) config.update(new_fields, allow_new=True)
config.save_json(copy_config_path) # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
json.dump(config.to_dict(), f, indent=4)
# copy model stats file if available # copy model stats file if available
if config.audio.stats_path is not None: if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy") copy_stats_path = os.path.join(out_path, "scale_stats.npy")
if not os.path.exists(copy_stats_path): filesystem = fsspec.get_mapper(copy_stats_path).fs
copyfile( if not filesystem.exists(copy_stats_path):
config.audio.stats_path, with fsspec.open(config.audio.stats_path, "rb") as source_file:
copy_stats_path, with fsspec.open(copy_stats_path, "wb") as target_file:
) shutil.copyfileobj(source_file, target_file)
def load_fsspec(path: str, **kwargs) -> Any:
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
Args:
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.load.
Returns:
Object stored in path.
"""
with fsspec.open(path, "rb") as f:
return torch.load(f, **kwargs)
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
try: try:
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
except ModuleNotFoundError: except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler pickle_tts.Unpickler = RenamingUnpickler
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
model.load_state_dict(state["model"]) model.load_state_dict(state["model"])
if use_cuda: if use_cuda:
model.cuda() model.cuda()
@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli
return model, state return model, state
def save_fsspec(state: Any, path: str, **kwargs):
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
Args:
state: State object to save
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.save.
"""
with fsspec.open(path, "wb") as f:
torch.save(state, f, **kwargs)
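# Hedged usage sketch (not part of the diff): the helpers above accept any fsspec URL,
# e.g. an S3 bucket (placeholder path; requires the matching filesystem package such as s3fs).
#
# state = {"model": model.state_dict(), "step": 1000}
# save_fsspec(state, "s3://my-bucket/run-1/checkpoint_1000.pth.tar")
# restored = load_fsspec("s3://my-bucket/run-1/checkpoint_1000.pth.tar", map_location="cpu")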
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"): if hasattr(model, "module"):
model_state = model.module.state_dict() model_state = model.module.state_dict()
@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }
state.update(kwargs) state.update(kwargs)
torch.save(state, output_path) save_fsspec(state, output_path)
def save_checkpoint( def save_checkpoint(
@ -147,18 +178,16 @@ def save_best_model(
model_loss=current_loss, model_loss=current_loss,
**kwargs, **kwargs,
) )
fs = fsspec.get_mapper(out_path).fs
# only delete previous if current is saved successfully # only delete previous if current is saved successfully
if not keep_all_best or (current_step < keep_after): if not keep_all_best or (current_step < keep_after):
model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar")) model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
for model_name in model_names: for model_name in model_names:
if os.path.basename(model_name) == best_model_name: if os.path.basename(model_name) != best_model_name:
continue fs.rm(model_name)
os.remove(model_name) # create a shortcut which always points to the currently best model
# create symlink to best model for convinience shortcut_name = "best_model.pth.tar"
link_name = "best_model.pth.tar" shortcut_path = os.path.join(out_path, shortcut_name)
link_path = os.path.join(out_path, link_name) fs.copy(checkpoint_path, shortcut_path)
if os.path.islink(link_path) or os.path.isfile(link_path):
os.remove(link_path)
os.symlink(best_model_name, os.path.join(out_path, link_name))
best_loss = current_loss best_loss = current_loss
return best_loss return best_loss

View File

@ -1,2 +1,24 @@
from TTS.utils.logging.console_logger import ConsoleLogger from TTS.utils.logging.console_logger import ConsoleLogger
from TTS.utils.logging.tensorboard_logger import TensorboardLogger from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from TTS.utils.logging.wandb_logger import WandbLogger
def init_logger(config):
if config.dashboard_logger == "tensorboard":
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)
elif config.dashboard_logger == "wandb":
project_name = config.model
if config.project_name:
project_name = config.project_name
dashboard_logger = WandbLogger(
project=project_name,
name=config.run_name,
config=config,
entity=config.wandb_entity,
)
dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
return dashboard_logger
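A hedged sketch of the config fields `init_logger()` reads; the field names are taken from the code above and are assumed to exist on the training config, while the values are illustrative:

config.dashboard_logger = "wandb"        # or "tensorboard"
config.project_name = "vits-ljspeech"    # hypothetical project
config.run_name = "run-1"
config.wandb_entity = None
dashboard_logger = init_logger(config)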

View File

@ -38,7 +38,7 @@ class ConsoleLogger:
def print_train_start(self): def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict): def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_dict):
indent = " | > " indent = " | > "
print() print()
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format( log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
@ -50,13 +50,6 @@ class ConsoleLogger:
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
else: else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value) log_text += "{}{}: {:.5f} \n".format(indent, key, value)
for idx, (key, value) in enumerate(log_dict.items()):
if isinstance(value, list):
log_text += f"{indent}{key}: {value[0]:.{value[1]}f}"
else:
log_text += f"{indent}{key}: {value}"
if idx < len(log_dict) - 1:
log_text += "\n"
print(log_text, flush=True) print(log_text, flush=True)
# pylint: disable=unused-argument # pylint: disable=unused-argument

View File

@ -7,10 +7,8 @@ class TensorboardLogger(object):
def __init__(self, log_dir, model_name): def __init__(self, log_dir, model_name):
self.model_name = model_name self.model_name = model_name
self.writer = SummaryWriter(log_dir) self.writer = SummaryWriter(log_dir)
self.train_stats = {}
self.eval_stats = {}
def tb_model_weights(self, model, step): def model_weights(self, model, step):
layer_num = 1 layer_num = 1
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.numel() == 1: if param.numel() == 1:
@ -41,32 +39,41 @@ class TensorboardLogger(object):
except RuntimeError: except RuntimeError:
traceback.print_exc() traceback.print_exc()
def tb_train_step_stats(self, step, stats): def train_step_stats(self, step, stats):
self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step) self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)
def tb_train_epoch_stats(self, step, stats): def train_epoch_stats(self, step, stats):
self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step) self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step)
def tb_train_figures(self, step, figures): def train_figures(self, step, figures):
self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step) self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step)
def tb_train_audios(self, step, audios, sample_rate): def train_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate) self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate)
def tb_eval_stats(self, step, stats): def eval_stats(self, step, stats):
self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step) self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step)
def tb_eval_figures(self, step, figures): def eval_figures(self, step, figures):
self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step) self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step)
def tb_eval_audios(self, step, audios, sample_rate): def eval_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate) self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate)
def tb_test_audios(self, step, audios, sample_rate): def test_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate) self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate)
def tb_test_figures(self, step, figures): def test_figures(self, step, figures):
self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step) self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step)
def tb_add_text(self, title, text, step): def add_text(self, title, text, step):
self.writer.add_text(title, text, step) self.writer.add_text(title, text, step)
def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613, R0201
yield
def flush(self):
self.writer.flush()
def finish(self):
self.writer.close()

View File

@ -0,0 +1,111 @@
# pylint: disable=W0613
import traceback
from pathlib import Path
try:
import wandb
from wandb import finish, init # pylint: disable=W0611
except ImportError:
wandb = None
class WandbLogger:
def __init__(self, **kwargs):
if not wandb:
raise Exception("install wandb using `pip install wandb` to use WandbLogger")
self.run = None
self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
self.model_name = self.run.config.model
self.log_dict = {}
def model_weights(self, model):
layer_num = 1
for name, param in model.named_parameters():
if param.numel() == 1:
self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()})
else:
self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()})
self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
self.log_dict["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param)
self.log_dict["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad)
layer_num += 1
def dict_to_scalar(self, scope_name, stats):
for key, value in stats.items():
self.log_dict["{}/{}".format(scope_name, key)] = value
def dict_to_figure(self, scope_name, figures):
for key, value in figures.items():
self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value)
def dict_to_audios(self, scope_name, audios, sample_rate):
for key, value in audios.items():
if value.dtype == "float16":
value = value.astype("float32")
try:
self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
except RuntimeError:
traceback.print_exc()
def log(self, log_dict, prefix="", flush=False):
for key, value in log_dict.items():
self.log_dict[prefix + key] = value
if flush: # for cases where you don't want to accumulate data
self.flush()
def train_step_stats(self, step, stats):
self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats)
def train_epoch_stats(self, step, stats):
self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats)
def train_figures(self, step, figures):
self.dict_to_figure(f"{self.model_name}_TrainFigures", figures)
def train_audios(self, step, audios, sample_rate):
self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate)
def eval_stats(self, step, stats):
self.dict_to_scalar(f"{self.model_name}_EvalStats", stats)
def eval_figures(self, step, figures):
self.dict_to_figure(f"{self.model_name}_EvalFigures", figures)
def eval_audios(self, step, audios, sample_rate):
self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate)
def test_audios(self, step, audios, sample_rate):
self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate)
def test_figures(self, step, figures):
self.dict_to_figure(f"{self.model_name}_TestFigures", figures)
def add_text(self, title, text, step):
pass
def flush(self):
if self.run:
wandb.log(self.log_dict)
self.log_dict = {}
def finish(self):
if self.run:
self.run.finish()
def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
if not self.run:
return
name = "_".join([self.run.id, name])
artifact = wandb.Artifact(name, type=artifact_type)
data_path = Path(file_or_dir)
if data_path.is_dir():
artifact.add_dir(str(data_path))
elif data_path.is_file():
artifact.add_file(str(data_path))
self.run.log_artifact(artifact, aliases=aliases)

View File

@ -64,6 +64,7 @@ class ModelManager(object):
def list_models(self): def list_models(self):
print(" Name format: type/language/dataset/model") print(" Name format: type/language/dataset/model")
models_name_list = [] models_name_list = []
model_count = 1
for model_type in self.models_dict: for model_type in self.models_dict:
for lang in self.models_dict[model_type]: for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]: for dataset in self.models_dict[model_type][lang]:
@ -71,10 +72,11 @@ class ModelManager(object):
model_full_name = f"{model_type}--{lang}--{dataset}--{model}" model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name) output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path): if os.path.exists(output_path):
print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]") print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else: else:
print(f" >: {model_type}/{lang}/{dataset}/{model}") print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
return models_name_list return models_name_list
def download_model(self, model_name): def download_model(self, model_name):

View File

@ -12,7 +12,6 @@ from TTS.tts.utils.speakers import SpeakerManager
# pylint: disable=unused-wildcard-import # pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, trim_silence from TTS.tts.utils.synthesis import synthesis, trim_silence
from TTS.tts.utils.text import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
@ -103,6 +102,34 @@ class Synthesizer(object):
self.num_speakers = self.speaker_manager.num_speakers self.num_speakers = self.speaker_manager.num_speakers
self.d_vector_dim = self.speaker_manager.d_vector_dim self.d_vector_dim = self.speaker_manager.d_vector_dim
def _set_tts_speaker_file(self):
"""Set the TTS speaker file used by a multi-speaker model."""
# setup if multi-speaker settings are in the global model config
if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
if self.tts_config.use_d_vector_file:
self.tts_speakers_file = (
self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
)
self.tts_config["d_vector_file"] = self.tts_speakers_file
else:
self.tts_speakers_file = (
self.tts_speakers_file if self.tts_speakers_file else self.tts_config["speakers_file"]
)
# setup if multi-speaker settings are in the model args config
if (
self.tts_speakers_file is None
and hasattr(self.tts_config, "model_args")
and hasattr(self.tts_config.model_args, "use_speaker_embedding")
and self.tts_config.model_args.use_speaker_embedding
):
_args = self.tts_config.model_args
if _args.use_d_vector_file:
self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["d_vector_file"]
_args["d_vector_file"] = self.tts_speakers_file
else:
self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["speakers_file"]
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
"""Load the TTS model. """Load the TTS model.
@ -113,29 +140,15 @@ class Synthesizer(object):
""" """
# pylint: disable=global-statement # pylint: disable=global-statement
global symbols, phonemes
self.tts_config = load_config(tts_config_path) self.tts_config = load_config(tts_config_path)
self.use_phonemes = self.tts_config.use_phonemes self.use_phonemes = self.tts_config.use_phonemes
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
if self.tts_config.has("characters") and self.tts_config.characters:
symbols, phonemes = make_symbols(**self.tts_config.characters)
if self.use_phonemes:
self.input_size = len(phonemes)
else:
self.input_size = len(symbols)
if self.tts_config.use_speaker_embedding is True:
self.tts_speakers_file = (
self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
)
self.tts_config["d_vector_file"] = self.tts_speakers_file
self.tts_model = setup_tts_model(config=self.tts_config) self.tts_model = setup_tts_model(config=self.tts_config)
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda: if use_cuda:
self.tts_model.cuda() self.tts_model.cuda()
self._set_tts_speaker_file()
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
"""Load the vocoder model. """Load the vocoder model.
@ -187,15 +200,22 @@ class Synthesizer(object):
""" """
start_time = time.time() start_time = time.time()
wavs = [] wavs = []
speaker_embedding = None
sens = self.split_into_sentences(text) sens = self.split_into_sentences(text)
print(" > Text splitted to sentences.") print(" > Text splitted to sentences.")
print(sens) print(sens)
# handle multi-speaker
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file: if self.tts_speakers_file:
# get the speaker embedding from the saved d_vectors.
if speaker_idx and isinstance(speaker_idx, str): if speaker_idx and isinstance(speaker_idx, str):
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0] if self.tts_config.use_d_vector_file:
# get the speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
else:
# get speaker idx from the speaker name
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
elif not speaker_idx and not speaker_wav: elif not speaker_idx and not speaker_wav:
raise ValueError( raise ValueError(
" [!] Look like you use a multi-speaker model. " " [!] Look like you use a multi-speaker model. "
@ -224,14 +244,14 @@ class Synthesizer(object):
CONFIG=self.tts_config, CONFIG=self.tts_config,
use_cuda=self.use_cuda, use_cuda=self.use_cuda,
ap=self.ap, ap=self.ap,
speaker_id=None, speaker_id=speaker_id,
style_wav=style_wav, style_wav=style_wav,
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars, enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
use_griffin_lim=use_gl, use_griffin_lim=use_gl,
d_vector=speaker_embedding, d_vector=speaker_embedding,
) )
waveform = outputs["wav"] waveform = outputs["wav"]
mel_postnet_spec = outputs["model_outputs"] mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy()
if not use_gl: if not use_gl:
# denormalize tts output based on tts audio config # denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T

View File

@ -1,5 +1,5 @@
import importlib import importlib
from typing import Dict from typing import Dict, List
import torch import torch
@ -48,7 +48,7 @@ def get_scheduler(
def get_optimizer( def get_optimizer(
optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
) -> torch.optim.Optimizer: ) -> torch.optim.Optimizer:
"""Find, initialize and return a optimizer. """Find, initialize and return a optimizer.
@ -66,4 +66,6 @@ def get_optimizer(
optimizer = getattr(module, "RAdam") optimizer = getattr(module, "RAdam")
else: else:
optimizer = getattr(torch.optim, optimizer_name) optimizer = getattr(torch.optim, optimizer_name)
return optimizer(model.parameters(), lr=lr, **optimizer_params) if model is not None:
parameters = model.parameters()
return optimizer(parameters, lr=lr, **optimizer_params)
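As the hunk above shows, `get_optimizer` now also accepts an explicit parameter list instead of a whole model. The sketch below shows both call styles; the stand-in module and hyper-parameters are illustrative only.

```python
import torch

from TTS.utils.trainer_utils import get_optimizer

model = torch.nn.Linear(80, 80)  # stand-in module, not a TTS model

# Old style: pass the model, get_optimizer() calls model.parameters() itself.
opt_full = get_optimizer("AdamW", {"betas": (0.8, 0.99), "weight_decay": 0.01}, lr=1e-4, model=model)

# New style: pass an explicit parameter list, e.g. to optimize only a subset of weights.
opt_subset = get_optimizer("AdamW", {"weight_decay": 0.0}, lr=1e-4, parameters=[model.weight])
```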

View File

@ -24,8 +24,10 @@ def setup_model(config: Coqpit):
elif config.model.lower() == "wavegrad": elif config.model.lower() == "wavegrad":
MyModel = getattr(MyModel, "Wavegrad") MyModel = getattr(MyModel, "Wavegrad")
else: else:
MyModel = getattr(MyModel, to_camel(config.model)) try:
raise ValueError(f"Model {config.model} not exist!") MyModel = getattr(MyModel, to_camel(config.model))
except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e
model = MyModel(config) model = MyModel(config)
return model return model

View File

@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
@ -222,7 +223,7 @@ class GAN(BaseVocoder):
checkpoint_path (str): Checkpoint file path. checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If true, load the model for inference. Defaults to False. eval (bool, optional): If true, load the model for inference. Defaults to False.
""" """
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# band-aid for older than v0.0.15 GAN models # band-aid for older than v0.0.15 GAN models
if "model_disc" in state: if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval) self.model_g.load_checkpoint(config, checkpoint_path, eval)

View File

@ -33,10 +33,10 @@ class DiscriminatorP(torch.nn.Module):
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
self.convs = nn.ModuleList( self.convs = nn.ModuleList(
[ [
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
] ]
) )
@ -81,15 +81,15 @@ class MultiPeriodDiscriminator(torch.nn.Module):
Periods are suggested to be prime numbers to reduce the overlap between each discriminator. Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
""" """
def __init__(self): def __init__(self, use_spectral_norm=False):
super().__init__() super().__init__()
self.discriminators = nn.ModuleList( self.discriminators = nn.ModuleList(
[ [
DiscriminatorP(2), DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
DiscriminatorP(3), DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
DiscriminatorP(5), DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
DiscriminatorP(7), DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
DiscriminatorP(11), DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
] ]
) )
@ -99,7 +99,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
x (Tensor): input waveform. x (Tensor): input waveform.
Returns: Returns:
[List[Tensor]]: list of scores from each discriminator. [List[Tensor]]: list of scores from each discriminator.
[List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer. [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.
Shapes: Shapes:

View File

@ -5,6 +5,8 @@ import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils import remove_weight_norm, weight_norm
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1
@ -168,6 +170,10 @@ class HifiganGenerator(torch.nn.Module):
upsample_initial_channel, upsample_initial_channel,
upsample_factors, upsample_factors,
inference_padding=5, inference_padding=5,
cond_channels=0,
conv_pre_weight_norm=True,
conv_post_weight_norm=True,
conv_post_bias=True,
): ):
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
@ -216,12 +222,21 @@ class HifiganGenerator(torch.nn.Module):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d)) self.resblocks.append(resblock(ch, k, d))
# post convolution layer # post convolution layer
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3)) self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
def forward(self, x): if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre)
if not conv_post_weight_norm:
remove_weight_norm(self.conv_post)
def forward(self, x, g=None):
""" """
Args: Args:
x (Tensor): conditioning input tensor. x (Tensor): feature input tensor.
g (Tensor): global conditioning input tensor.
Returns: Returns:
Tensor: output waveform. Tensor: output waveform.
@ -231,6 +246,8 @@ class HifiganGenerator(torch.nn.Module):
Tensor: [B, 1, T] Tensor: [B, 1, T]
""" """
o = self.conv_pre(x) o = self.conv_pre(x)
if hasattr(self, "cond_layer"):
o = o + self.cond_layer(g)
for i in range(self.num_upsamples): for i in range(self.num_upsamples):
o = F.leaky_relu(o, LRELU_SLOPE) o = F.leaky_relu(o, LRELU_SLOPE)
o = self.ups[i](o) o = self.ups[i](o)
@ -275,7 +292,7 @@ class HifiganGenerator(torch.nn.Module):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -2,6 +2,7 @@ import torch
from torch import nn from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack from TTS.vocoder.layers.melgan import ResidualStack
@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -3,6 +3,7 @@ import math
import numpy as np import numpy as np
import torch import torch
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample from TTS.vocoder.layers.upsample import ConvUpsample
@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -1,3 +1,5 @@
from typing import List
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -10,18 +12,35 @@ LRELU_SLOPE = 0.1
class UnivnetGenerator(torch.nn.Module): class UnivnetGenerator(torch.nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
out_channels, out_channels: int,
hidden_channels, hidden_channels: int,
cond_channels, cond_channels: int,
upsample_factors, upsample_factors: List[int],
lvc_layers_each_block, lvc_layers_each_block: int,
lvc_kernel_size, lvc_kernel_size: int,
kpnet_hidden_channels, kpnet_hidden_channels: int,
kpnet_conv_size, kpnet_conv_size: int,
dropout, dropout: float,
use_weight_norm=True, use_weight_norm=True,
): ):
"""Univnet Generator network.
Paper: https://arxiv.org/pdf/2106.07889.pdf
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of channels of the output tensor.
hidden_channels (int): Number of hidden network channels.
cond_channels (int): Number of channels of the conditioning tensors.
upsample_factors (List[int]): List of upsample factors for the upsampling layers.
lvc_layers_each_block (int): Number of LVC layers in each block.
lvc_kernel_size (int): Kernel size of the LVC layers.
kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
kpnet_conv_size (int): Number of convolution channels in the key-point network.
dropout (float): Dropout rate.
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
"""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels

View File

@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder from TTS.vocoder.models.base_vocoder import BaseVocoder
@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -1,6 +1,7 @@
import datetime import datetime
import pickle import pickle
import fsspec
import tensorflow as tf import tensorflow as tf
@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }
state.update(kwargs) state.update(kwargs)
pickle.dump(state, open(output_path, "wb")) with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path): def load_checkpoint(model, checkpoint_path):
"""Load TF Vocoder model""" """Load TF Vocoder model"""
checkpoint = pickle.load(open(checkpoint_path, "rb")) with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights tf_vars = model.weights
for tf_var in tf_vars: for tf_var in tf_vars:

View File

@ -1,3 +1,4 @@
import fsspec
import tensorflow as tf import tensorflow as tf
@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=Tru
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None: if output_path is not None:
# same model binary if outputpath is provided # same model binary if outputpath is provided
with open(output_path, "wb") as f: with fsspec.open(output_path, "wb") as f:
f.write(tflite_model) f.write(tflite_model)
return None return None
return tflite_model return tflite_model

View File

@ -1,8 +1,11 @@
from typing import Dict
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from TTS.tts.utils.visual import plot_spectrogram from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
def interpolate_vocoder_input(scale_factor, spec): def interpolate_vocoder_input(scale_factor, spec):
@ -26,12 +29,24 @@ def interpolate_vocoder_input(scale_factor, spec):
return spec return spec
def plot_results(y_hat, y, ap, name_prefix): def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
"""Plot vocoder model results""" """Plot the predicted and the real waveform and their spectrograms.
Args:
y_hat (torch.tensor): Predicted waveform.
y (torch.tensor): Real waveform.
ap (AudioProcessor): Audio processor used to process the waveform.
name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
Returns:
Dict: output figures keyed by the name of the figures.
""" """Plot vocoder model results"""
if name_prefix is None:
name_prefix = ""
# select an instance from batch # select an instance from batch
y_hat = y_hat[0].squeeze(0).detach().cpu().numpy() y_hat = y_hat[0].squeeze().detach().cpu().numpy()
y = y[0].squeeze(0).detach().cpu().numpy() y = y[0].squeeze().detach().cpu().numpy()
spec_fake = ap.melspectrogram(y_hat).T spec_fake = ap.melspectrogram(y_hat).T
spec_real = ap.melspectrogram(y).T spec_real = ap.melspectrogram(y).T

View File

@ -2,4 +2,5 @@ furo
myst-parser == 0.15.1 myst-parser == 0.15.1
sphinx == 4.0.2 sphinx == 4.0.2
sphinx_inline_tabs sphinx_inline_tabs
sphinx_copybutton sphinx_copybutton
linkify-it-py

View File

@ -68,6 +68,8 @@ extensions = [
"sphinx_inline_tabs", "sphinx_inline_tabs",
] ]
myst_enable_extensions = ['linkify',]
# 'sphinxcontrib.katex', # 'sphinxcontrib.katex',
# 'sphinx.ext.autosectionlabel', # 'sphinx.ext.autosectionlabel',

View File

@ -44,6 +44,7 @@
:caption: `tts` Models :caption: `tts` Models
models/glow_tts.md models/glow_tts.md
models/vits.md
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2

View File

@ -0,0 +1,22 @@
# Glow TTS
Glow TTS is a normalizing flow model for text-to-speech. It is built on the generic Glow model previously
used in computer vision and vocoder models. It uses "monotonic alignment search" (MAS) to find the text-to-speech alignment
and uses the MAS output to train a separate duration predictor network for faster inference run-time.
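The snippet below is a condensed training sketch mirroring the LJSpeech GlowTTS recipe included in this release; the dataset path and batch sizes are illustrative and assume the LJSpeech layout used by that recipe.

```python
import os

from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig

output_path = os.path.dirname(os.path.abspath(__file__))

# Illustrative dataset location; point `path` at your LJSpeech copy.
dataset_config = BaseDatasetConfig(
    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)

config = GlowTTSConfig(
    batch_size=32,
    eval_batch_size=16,
    print_eval=True,
    mixed_precision=False,
    output_path=output_path,
    datasets=[dataset_config],
)

# init_training() returns the console and dashboard loggers used by the Trainer.
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit()
```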
## Important resources & papers
- GlowTTS: https://arxiv.org/abs/2005.11129
- Glow (Generative Flow with invertible 1x1 Convolutions): https://arxiv.org/abs/1807.03039
- Normalizing Flows: https://blog.evjang.com/2018/01/nf1.html
## GlowTTS Config
```{eval-rst}
.. autoclass:: TTS.tts.configs.glow_tts_config.GlowTTSConfig
:members:
```
## GlowTTS Model
```{eval-rst}
.. autoclass:: TTS.tts.models.glow_tts.GlowTTS
:members:
```

View File

@ -0,0 +1,33 @@
# VITS
VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech) is an end-to-end
(encoder -> vocoder together) TTS model that takes advantage of SOTA deep-learning techniques such as GANs, VAEs, and
Normalizing Flows. It does not require external alignment annotations and learns the text-to-audio alignment
using MAS, as explained in the paper. The model architecture is a combination of the GlowTTS encoder and the HiFiGAN vocoder.
It is a feed-forward model with an x67.12 real-time factor on a GPU.
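The snippet below condenses the LJSpeech VITS recipe added in this release into a minimal training sketch; the audio settings and batch size mirror that recipe and are illustrative rather than tuned values.

```python
import os

from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, VitsConfig

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)

# Feature extraction settings from the recipe; normalization is disabled as VITS works on raw-scale features.
audio_config = BaseAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    mel_fmin=0,
    mel_fmax=None,
    signal_norm=False,
    do_amp_to_db_linear=False,
)

config = VitsConfig(
    audio=audio_config,
    run_name="vits_ljspeech",
    batch_size=48,
    eval_batch_size=16,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    mixed_precision=True,
    output_path=output_path,
    datasets=[dataset_config],
)

args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
trainer.fit()
```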
## Important resources & papers
- VITS: https://arxiv.org/pdf/2106.06103.pdf
- Neural Spline Flows: https://arxiv.org/abs/1906.04032
- Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf
- Generative Adversarial Networks: https://arxiv.org/abs/1406.2661
- HiFiGAN: https://arxiv.org/abs/2010.05646
- Normalizing Flows: https://blog.evjang.com/2018/01/nf1.html
## VitsConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.vits_config.VitsConfig
:members:
```
## VitsArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.vits.VitsArgs
:members:
```
## Vits Model
```{eval-rst}
.. autoclass:: TTS.tts.models.vits.Vits
:members:
```

View File

@ -85,6 +85,7 @@ We still support running training from CLI like in the old days. The same traini
```json ```json
{ {
"run_name": "my_run",
"model": "glow_tts", "model": "glow_tts",
"batch_size": 32, "batch_size": 32,
"eval_batch_size": 16, "eval_batch_size": 16,

View File

@ -25,9 +25,7 @@
"import umap\n", "import umap\n",
"\n", "\n",
"from TTS.speaker_encoder.model import SpeakerEncoder\n", "from TTS.speaker_encoder.model import SpeakerEncoder\n",
"from TTS.utils.audio import AudioProcessor "from TTS.utils.audio import AudioProcessor\n",
\n",
"from TTS.tts.utils.generic_utils import load_config\n", "from TTS.tts.utils.generic_utils import load_config\n",
"\n", "\n",
"from bokeh.io import output_notebook, show\n", "from bokeh.io import output_notebook, show\n",

View File

@ -1,12 +1,12 @@
import os import os
from TTS.tts.configs import AlignTTSConfig from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig from TTS.tts.configs import AlignTTSConfig, BaseDatasetConfig
from TTS.trainer import init_training, Trainer, TrainingArgs
output_path = os.path.dirname(os.path.abspath(__file__)) output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")) dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
config = AlignTTSConfig( config = AlignTTSConfig(
batch_size=32, batch_size=32,
eval_batch_size=16, eval_batch_size=16,
@ -23,8 +23,8 @@ config = AlignTTSConfig(
print_eval=True, print_eval=True,
mixed_precision=False, mixed_precision=False,
output_path=output_path, output_path=output_path,
datasets=[dataset_config] datasets=[dataset_config],
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit() trainer.fit()

View File

@ -1,12 +1,12 @@
import os import os
from TTS.tts.configs import GlowTTSConfig from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig
from TTS.trainer import init_training, Trainer, TrainingArgs
output_path = os.path.dirname(os.path.abspath(__file__)) output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")) dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
config = GlowTTSConfig( config = GlowTTSConfig(
batch_size=32, batch_size=32,
eval_batch_size=16, eval_batch_size=16,
@ -23,8 +23,8 @@ config = GlowTTSConfig(
print_eval=True, print_eval=True,
mixed_precision=False, mixed_precision=False,
output_path=output_path, output_path=output_path,
datasets=[dataset_config] datasets=[dataset_config],
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit() trainer.fit()

View File

@ -24,6 +24,6 @@ config = HifiganConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit() trainer.fit()

View File

@ -1,8 +1,7 @@
import os import os
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.vocoder.configs import MultibandMelganConfig from TTS.vocoder.configs import MultibandMelganConfig
from TTS.trainer import init_training, Trainer, TrainingArgs
output_path = os.path.dirname(os.path.abspath(__file__)) output_path = os.path.dirname(os.path.abspath(__file__))
config = MultibandMelganConfig( config = MultibandMelganConfig(
@ -25,6 +24,6 @@ config = MultibandMelganConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit() trainer.fit()

View File

@ -1,6 +1,5 @@
import os import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.vocoder.configs import UnivnetConfig from TTS.vocoder.configs import UnivnetConfig
@ -25,6 +24,6 @@ config = UnivnetConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit() trainer.fit()

View File

@ -0,0 +1,52 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import BaseDatasetConfig, VitsConfig
output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
)
config = VitsConfig(
audio=audio_config,
run_name="vits_ljspeech",
batch_size=48,
eval_batch_size=16,
batch_group_size=0,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
print_eval=True,
mixed_precision=True,
max_seq_len=5000,
output_path=output_path,
datasets=[dataset_config],
)
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
trainer.fit()

View File

@ -1,10 +1,8 @@
import os import os
from TTS.trainer import Trainer, init_training from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.trainer import TrainingArgs
from TTS.vocoder.configs import WavegradConfig from TTS.vocoder.configs import WavegradConfig
output_path = os.path.dirname(os.path.abspath(__file__)) output_path = os.path.dirname(os.path.abspath(__file__))
config = WavegradConfig( config = WavegradConfig(
batch_size=32, batch_size=32,
@ -24,6 +22,6 @@ config = WavegradConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
trainer.fit() trainer.fit()

View File

@ -1,9 +1,8 @@
import os import os
from TTS.trainer import Trainer, init_training, TrainingArgs from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.vocoder.configs import WavernnConfig from TTS.vocoder.configs import WavernnConfig
output_path = os.path.dirname(os.path.abspath(__file__)) output_path = os.path.dirname(os.path.abspath(__file__))
config = WavernnConfig( config = WavernnConfig(
batch_size=64, batch_size=64,
@ -25,6 +24,6 @@ config = WavernnConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True) trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
trainer.fit() trainer.fit()

View File

@ -24,3 +24,4 @@ mecab-python3==1.0.3
unidic-lite==1.0.8 unidic-lite==1.0.8
# gruut+supported langs # gruut+supported langs
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
fsspec>=2021.04.0

View File

@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase):
r, r,
c.text_cleaner, c.text_cleaner,
compute_linear_spec=True, compute_linear_spec=True,
return_wav=True,
ap=self.ap, ap=self.ap,
meta_data=items, meta_data=items,
characters=c.characters, characters=c.characters,
@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase):
mel_lengths = data[5] mel_lengths = data[5]
stop_target = data[6] stop_target = data[6]
item_idx = data[7] item_idx = data[7]
wavs = data[11]
neg_values = text_input[text_input < 0] neg_values = text_input[text_input < 0]
check_count = len(neg_values) check_count = len(neg_values)
assert check_count == 0, " !! Negative values in text_input: {}".format(check_count) assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert isinstance(speaker_name[0], str) assert isinstance(speaker_name[0], str)
assert linear_input.shape[0] == c.batch_size assert linear_input.shape[0] == c.batch_size
assert linear_input.shape[2] == self.ap.fft_size // 2 + 1 assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
assert mel_input.shape[0] == c.batch_size assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.audio["num_mels"] assert mel_input.shape[2] == c.audio["num_mels"]
assert (
wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
# make sure that the computed mels and the waveform match and correctly computed
mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
assert abs(mel_diff.sum()) < 1e-5
# check normalization ranges # check normalization ranges
if self.ap.symmetric_norm: if self.ap.symmetric_norm:
assert mel_input.max() <= self.ap.max_norm assert mel_input.max() <= self.ap.max_norm

View File

@ -27,6 +27,7 @@ config = AlignTTSConfig(
"Be a voice, not an echo.", "Be a voice, not an echo.",
], ],
) )
config.audio.do_trim_silence = True config.audio.do_trim_silence = True
config.audio.trim_db = 60 config.audio.trim_db = 60
config.save_json(config_path) config.save_json(config_path)

View File

@ -29,6 +29,7 @@ config = Tacotron2Config(
"Be a voice, not an echo.", "Be a voice, not an echo.",
], ],
d_vector_file="tests/data/ljspeech/speakers.json", d_vector_file="tests/data/ljspeech/speakers.json",
d_vector_dim=256,
max_decoder_steps=50, max_decoder_steps=50,
) )

View File

@ -25,8 +25,68 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
class TacotronTrainTest(unittest.TestCase): class TacotronTrainTest(unittest.TestCase):
"""Test vanilla Tacotron2 model."""
def test_train_step(self): # pylint: disable=no-self-use def test_train_step(self): # pylint: disable=no-self-use
config = config_global.copy() config = config_global.copy()
config.use_speaker_embedding = False
config.num_speakers = 1
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[0] = 30
stop_targets = torch.zeros(8, 30, 1).float().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()) :, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron2(config).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for i in range(5):
outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class MultiSpeakerTacotronTrainTest(unittest.TestCase):
"""Test multi-speaker Tacotron2 with speaker embedding layer"""
@staticmethod
def test_train_step():
config = config_global.copy()
config.use_speaker_embedding = True
config.num_speakers = 5
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0] input_lengths = torch.sort(input_lengths, descending=True)[0]
@ -45,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):
criterion = MSELossMasked(seq_len_norm=False).to(device) criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device) criterion_st = nn.BCEWithLogitsLoss().to(device)
config.d_vector_dim = 55
model = Tacotron2(config).to(device) model = Tacotron2(config).to(device)
model.train() model.train()
model_ref = copy.deepcopy(model) model_ref = copy.deepcopy(model)
@ -76,65 +137,18 @@ class TacotronTrainTest(unittest.TestCase):
count += 1 count += 1
class MultiSpeakeTacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[0] = 30
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.rand(8, 55).to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()) :, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
config.d_vector_dim = 55
model = Tacotron2(config).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for i in range(5):
outputs = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_ids}
)
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class TacotronGSTTrainTest(unittest.TestCase): class TacotronGSTTrainTest(unittest.TestCase):
"""Test multi-speaker Tacotron2 with Global Style Token and Speaker Embedding"""
# pylint: disable=no-self-use # pylint: disable=no-self-use
def test_train_step(self): def test_train_step(self):
# with random gst mel style # with random gst mel style
config = config_global.copy() config = config_global.copy()
config.use_speaker_embedding = True
config.num_speakers = 10
config.use_gst = True
config.gst = GSTConfig()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0] input_lengths = torch.sort(input_lengths, descending=True)[0]
@ -247,9 +261,17 @@ class TacotronGSTTrainTest(unittest.TestCase):
class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase): class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
"""Test multi-speaker Tacotron2 with Global Style Tokens and d-vector inputs."""
@staticmethod @staticmethod
def test_train_step(): def test_train_step():
config = config_global.copy() config = config_global.copy()
config.use_d_vector_file = True
config.use_gst = True
config.gst = GSTConfig()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.randint(100, 128, (8,)).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0] input_lengths = torch.sort(input_lengths, descending=True)[0]

View File

@ -0,0 +1,55 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = Tacotron2Config(
r=5,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
test_sentences=[
"Be a voice, not an echo.",
],
print_eval=True,
max_decoder_steps=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
f"--coqpit.output_path file://{output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
)
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -32,6 +32,61 @@ class TacotronTrainTest(unittest.TestCase):
@staticmethod @staticmethod
def test_train_step(): def test_train_step():
config = config_global.copy() config = config_global.copy()
config.use_speaker_embedding = False
config.num_speakers = 1
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128
mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[-1] = mel_spec.size(1)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()) :, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = L1LossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for _ in range(5):
outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
optimizer.zero_grad()
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class MultiSpeakeTacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
config.use_speaker_embedding = True
config.num_speakers = 5
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128 input_lengths[-1] = 128
@ -50,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):
criterion = L1LossMasked(seq_len_norm=False).to(device) criterion = L1LossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device) criterion_st = nn.BCEWithLogitsLoss().to(device)
config.d_vector_dim = 55
        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
        model.train()
        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
@ -80,63 +136,14 @@ class TacotronTrainTest(unittest.TestCase):
            count += 1
class MultiSpeakeTacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = config_global.copy()
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128
mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[-1] = mel_spec.size(1)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_embeddings = torch.rand(8, 55).to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()) :, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = L1LossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
config.d_vector_dim = 55
model = Tacotron(config).to(device) # FIXME: missing num_speakers parameter to Tacotron ctor
model.train()
print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=config.lr)
for _ in range(5):
outputs = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
)
optimizer.zero_grad()
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore the pre-highway layer since it is applied conditionally
# if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class TacotronGSTTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
config.use_speaker_embedding = True
config.num_speakers = 10
config.use_gst = True
config.gst = GSTConfig()
        # with random gst mel style
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
@ -244,6 +251,11 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = config_global.copy()
config.use_d_vector_file = True
config.use_gst = True
config.gst = GSTConfig()
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
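The least obvious step in the Tacotron test hunks above is how the frame-level stop targets are collapsed to one flag per decoder step. A minimal, self-contained sketch of that reshaping, assuming a hypothetical reduction factor r=5 (the real value comes from `config.r` in the test config):

import torch

batch_size, max_frames, r = 8, 30, 5  # hypothetical sizes; config.r drives the real test
mel_lengths = torch.randint(20, 30, (batch_size,)).long()
mel_lengths[-1] = max_frames

# 1 marks frames past the end of each utterance
stop_targets = torch.zeros(batch_size, max_frames, 1)
for b, length in enumerate(mel_lengths):
    stop_targets[b, int(length.item()):, 0] = 1.0

# group r frames per decoder step; a step is a "stop" step if any of its grouped frames is past the end
stop_targets = stop_targets.view(batch_size, max_frames // r, -1)  # (B, steps, r)
stop_targets = (stop_targets.sum(2) > 0.0).float()                 # (B, steps)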

View File

@ -0,0 +1,54 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
"Be a voice, not an echo.",
],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
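The `--coqpit.datasets.0.*` flags above override the dataset fields of the saved config at the command line. An equivalent, purely programmatic setup might look roughly like the sketch below; the `BaseDatasetConfig` import path and its field names are assumptions inferred from the flag names, not taken from this diff.

import os

from TTS.config.shared_configs import BaseDatasetConfig  # assumed import path
from TTS.tts.configs import VitsConfig

# placeholder output location for the sketch
output_path = os.path.join("tests", "outputs", "train_outputs")

dataset_config = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    meta_file_attn_mask="tests/data/ljspeech/metadata_attn_mask.txt",
)

config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    epochs=1,
    test_delay_epochs=0,
    output_path=output_path,
    datasets=[dataset_config],
)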