From 49bac724c00ef5e2370b4a084e10fd1ed04b500f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 12 Jul 2022 18:49:58 +0200 Subject: [PATCH] Implement VitsAudioConfig (#1556) * Implement VitsAudioConfig * Update VITS LJSpeech recipe * Update VITS VCTK recipe * Make style * Add missing decorator * Add missing param * Make style * Update recipes * Fix test * Bug fix * Exclude tests folder * Make linter * Make style --- MANIFEST.in | 3 +- TTS/tts/configs/vits_config.py | 6 +++- TTS/tts/layers/losses.py | 2 +- TTS/tts/models/vits.py | 16 +++++++++++ TTS/tts/utils/ssim.py | 4 +-- TTS/utils/synthesizer.py | 2 +- .../ljspeech/fast_pitch/train_fast_pitch.py | 1 - .../ljspeech/fast_speech/train_fast_speech.py | 1 - .../speedy_speech/train_speedy_speech.py | 1 - .../train_capacitron_t2.py | 1 - recipes/ljspeech/vits_tts/train_vits.py | 23 ++++----------- .../multilingual/vits_tts/train_vits_tts.py | 15 ++-------- .../speedy_speech/train_speedy_speech.py | 1 - recipes/thorsten_DE/vits_tts/train_vits.py | 13 ++------- recipes/vctk/vits/train_vits.py | 22 +++------------ setup.py | 2 +- tests/tts_tests/test_vits.py | 28 ++++++++++++++++--- 17 files changed, 65 insertions(+), 76 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 82ecadcb..321d3999 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -11,4 +11,5 @@ recursive-include TTS *.md recursive-include TTS *.py recursive-include TTS *.pyx recursive-include images *.png - +recursive-exclude tests * +prune tests* diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index a8c7f91d..df9116f3 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import List from TTS.tts.configs.shared_configs import BaseTTSConfig -from TTS.tts.models.vits import VitsArgs +from TTS.tts.models.vits import VitsArgs, VitsAudioConfig @dataclass @@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig): model_args (VitsArgs): Model architecture arguments. Defaults to `VitsArgs()`. + audio (VitsAudioConfig): + Audio processing configuration. Defaults to `VitsAudioConfig()`. + grad_clip (List): Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. @@ -94,6 +97,7 @@ class VitsConfig(BaseTTSConfig): model: str = "vits" # model specific params model_args: VitsArgs = field(default_factory=VitsArgs) + audio: VitsAudioConfig = VitsAudioConfig() # optimizer grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 4430d9ff..816813c8 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -137,7 +137,7 @@ class SSIMLoss(torch.nn.Module): if ssim_loss.item() < 0.0: print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0") - ssim_loss = torch.tensor([0.0]) + ssim_loss = torch.tensor([0.0]) return ssim_loss diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 9263c0b1..f0920bd6 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -200,6 +200,22 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm return spec +############################# +# CONFIGS +############################# + + +@dataclass +class VitsAudioConfig(Coqpit): + fft_size: int = 1024 + sample_rate: int = 22050 + win_length: int = 1024 + hop_length: int = 256 + num_mels: int = 80 + mel_fmin: int = 0 + mel_fmax: int = None + + ############################## # DATASET ############################## diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index 2bca1be5..4bc3befc 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -16,9 +16,9 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor: """ if reduction == "none": return x - elif reduction == "mean": + if reduction == "mean": return x.mean(dim=0) - elif reduction == "sum": + if reduction == "sum": return x.sum(dim=0) raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}") diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 2f319809..170bb223 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -307,7 +307,7 @@ class Synthesizer(object): waveform = waveform.squeeze() # trim silence - if self.tts_config.audio["do_trim_silence"] is True: + if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]: waveform = trim_silence(waveform, self.tts_model.ap) wavs += list(waveform) diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index a84658f3..1c0e4702 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -54,7 +54,6 @@ config = FastPitchConfig( print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py index 0245dd93..ab7e8841 100644 --- a/recipes/ljspeech/fast_speech/train_fast_speech.py +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -53,7 +53,6 @@ config = FastSpeechConfig( print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 1ab3db1c..fd3c8679 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -46,7 +46,6 @@ config = SpeedySpeechConfig( print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py b/recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py index 6bb0aed7..a1882451 100644 --- a/recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py +++ b/recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py @@ -68,7 +68,6 @@ config = Tacotron2Config( print_step=25, print_eval=True, mixed_precision=False, - sort_by_audio_len=True, seq_len_norm=True, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index c070b3f1..94e230a1 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -2,11 +2,10 @@ import os from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.vits import Vits +from TTS.tts.models.vits import Vits, VitsAudioConfig from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -14,21 +13,8 @@ output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/") ) -audio_config = BaseAudioConfig( - sample_rate=22050, - win_length=1024, - hop_length=256, - num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=True, - trim_db=45, - mel_fmin=0, - mel_fmax=None, - spec_gain=1.0, - signal_norm=False, - do_amp_to_db_linear=False, +audio_config = VitsAudioConfig( + sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None ) config = VitsConfig( @@ -37,7 +23,7 @@ config = VitsConfig( batch_size=32, eval_batch_size=16, batch_group_size=5, - num_loader_workers=0, + num_loader_workers=8, num_eval_loader_workers=4, run_eval=True, test_delay_epochs=-1, @@ -52,6 +38,7 @@ config = VitsConfig( mixed_precision=True, output_path=output_path, datasets=[dataset_config], + cudnn_benchmark=False, ) # INITIALIZE THE AUDIO PROCESSOR diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py index 0e650ade..0a9cced4 100644 --- a/recipes/multilingual/vits_tts/train_vits_tts.py +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -3,11 +3,10 @@ from glob import glob from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs +from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer @@ -22,22 +21,13 @@ dataset_config = [ for path in dataset_paths ] -audio_config = BaseAudioConfig( +audio_config = VitsAudioConfig( sample_rate=16000, win_length=1024, hop_length=256, num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=False, - trim_db=23.0, mel_fmin=0, mel_fmax=None, - spec_gain=1.0, - signal_norm=True, - do_amp_to_db_linear=False, - resample=False, ) vitsArgs = VitsArgs( @@ -69,7 +59,6 @@ config = VitsConfig( use_language_weighted_sampler=True, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, min_audio_len=32 * 256 * 4, max_audio_len=160000, output_path=output_path, diff --git a/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py b/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py index 1a4c8ec8..8f241306 100644 --- a/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py +++ b/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py @@ -60,7 +60,6 @@ config = SpeedySpeechConfig( "Dieser Kuchen ist großartig. Er ist so lecker und feucht.", "Vor dem 22. November 1963.", ], - sort_by_audio_len=True, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/thorsten_DE/vits_tts/train_vits.py b/recipes/thorsten_DE/vits_tts/train_vits.py index 86a7dfe6..25c57b64 100644 --- a/recipes/thorsten_DE/vits_tts/train_vits.py +++ b/recipes/thorsten_DE/vits_tts/train_vits.py @@ -2,11 +2,10 @@ import os from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.vits import Vits +from TTS.tts.models.vits import Vits, VitsAudioConfig from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.downloaders import download_thorsten_de @@ -21,21 +20,13 @@ if not os.path.exists(dataset_config.path): print("Downloading dataset") download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0]) -audio_config = BaseAudioConfig( +audio_config = VitsAudioConfig( sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=True, - trim_db=45, mel_fmin=0, mel_fmax=None, - spec_gain=1.0, - signal_norm=False, - do_amp_to_db_linear=False, ) config = VitsConfig( diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py index 88fd7de9..814d0989 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -2,11 +2,10 @@ import os from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.vits import Vits, VitsArgs +from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -17,22 +16,8 @@ dataset_config = BaseDatasetConfig( ) -audio_config = BaseAudioConfig( - sample_rate=22050, - win_length=1024, - hop_length=256, - num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=True, - trim_db=23.0, - mel_fmin=0, - mel_fmax=None, - spec_gain=1.0, - signal_norm=False, - do_amp_to_db_linear=False, - resample=True, +audio_config = VitsAudioConfig( + sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None ) vitsArgs = VitsArgs( @@ -62,6 +47,7 @@ config = VitsConfig( max_text_len=325, # change this if you have a larger VRAM than 16GB output_path=output_path, datasets=[dataset_config], + cudnn_benchmark=False, ) # INITIALIZE THE AUDIO PROCESSOR diff --git a/setup.py b/setup.py index 3c860949..f95d79f1 100644 --- a/setup.py +++ b/setup.py @@ -90,7 +90,7 @@ setup( # ext_modules=find_cython_extensions(), # package include_package_data=True, - packages=find_packages(include=["TTS*"]), + packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]), package_data={ "TTS": [ "VERSION", diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index b9cebb5a..7d474c20 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -9,7 +9,17 @@ from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_test from TTS.config import load_config from TTS.encoder.utils.generic_utils import setup_encoder_model from TTS.tts.configs.vits_config import VitsConfig -from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec +from TTS.tts.models.vits import ( + Vits, + VitsArgs, + VitsAudioConfig, + amp_to_db, + db_to_amp, + load_audio, + spec_to_mel, + wav_to_mel, + wav_to_spec, +) from TTS.tts.utils.speakers import SpeakerManager LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") @@ -421,8 +431,10 @@ class TestVits(unittest.TestCase): self._check_parameter_changes(model, model_ref) def test_train_step_upsampling(self): + """Upsampling by the decoder upsampling layers""" # setup the model with torch.autograd.set_detect_anomaly(True): + audio_config = VitsAudioConfig(sample_rate=22050) model_args = VitsArgs( num_chars=32, spec_segment_size=10, @@ -430,7 +442,7 @@ class TestVits(unittest.TestCase): interpolate_z=False, upsample_rates_decoder=[8, 8, 4, 2], ) - config = VitsConfig(model_args=model_args) + config = VitsConfig(model_args=model_args, audio=audio_config) model = Vits(config).to(device) model.train() # model to train @@ -459,10 +471,18 @@ class TestVits(unittest.TestCase): self._check_parameter_changes(model, model_ref) def test_train_step_upsampling_interpolation(self): + """Upsampling by interpolation""" # setup the model with torch.autograd.set_detect_anomaly(True): - model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True) - config = VitsConfig(model_args=model_args) + audio_config = VitsAudioConfig(sample_rate=22050) + model_args = VitsArgs( + num_chars=32, + spec_segment_size=10, + encoder_sample_rate=11025, + interpolate_z=True, + upsample_rates_decoder=[8, 8, 2, 2], + ) + config = VitsConfig(model_args=model_args, audio=audio_config) model = Vits(config).to(device) model.train() # model to train