From 4556c61902ca505b6570fc0f3c2f3a7b5d8a4238 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 19 Apr 2022 09:18:49 +0200 Subject: [PATCH] Update fastpitche2e recipe --- TTS/tts/configs/fast_pitch_e2e_config.py | 26 ++++++++----- .../fast_pitch_e2e/train_fast_pitch_e2e.py | 37 +++++++------------ 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/TTS/tts/configs/fast_pitch_e2e_config.py b/TTS/tts/configs/fast_pitch_e2e_config.py index 7f7126ec..f86cf459 100644 --- a/TTS/tts/configs/fast_pitch_e2e_config.py +++ b/TTS/tts/configs/fast_pitch_e2e_config.py @@ -2,12 +2,12 @@ from dataclasses import dataclass, field from typing import List from TTS.tts.configs.shared_configs import BaseTTSConfig -from TTS.tts.models.forward_tts_e2e import ForwardTTSE2EArgs +from TTS.tts.models.forward_tts_e2e import ForwardTTSE2eArgs @dataclass -class FastPitchE2EConfig(BaseTTSConfig): - """Configure `ForwardTTS` as FastPitch model. +class FastPitchE2eConfig(BaseTTSConfig): + """Configure `ForwardTTSE2e` as FastPitchE2e model. 
Example: @@ -103,13 +103,13 @@ class FastPitchE2EConfig(BaseTTSConfig): """ model: str = "fast_pitch_e2e_hifigan" - base_model: str = "forward_tts" + base_model: str = "forward_tts_e2e" # model specific params - # model_args: ForwardTTSE2EArgs = ForwardTTSE2EArgs(vocoder_config=HifiganConfig()) - model_args: ForwardTTSE2EArgs = ForwardTTSE2EArgs() + # model_args: ForwardTTSE2eArgs = ForwardTTSE2eArgs(vocoder_config=HifiganConfig()) + model_args: ForwardTTSE2eArgs = ForwardTTSE2eArgs() - # # multi-speaker settings + # multi-speaker settings # num_speakers: int = 0 # speakers_file: str = None # use_speaker_embedding: bool = False @@ -142,11 +142,19 @@ class FastPitchE2EConfig(BaseTTSConfig): binary_align_loss_alpha: float = 0.1 binary_loss_warmup_epochs: int = 150 - # decoder loss params + # vocoder loss params disc_loss_alpha: float = 1.0 gen_loss_alpha: float = 1.0 feat_loss_alpha: float = 1.0 - mel_loss_alpha: float = 45.0 + mel_loss_alpha: float = 10.0 + multi_scale_stft_loss_alpha: float = 2.5 + multi_scale_stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) # data loader params return_wav: bool = True diff --git a/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py b/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py index 31b3da9c..b8463cd1 100644 --- a/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py +++ b/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py @@ -2,12 +2,12 @@ import os from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig -from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2EConfig +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.forward_tts_e2e import ForwardTTSE2E, ForwardTTSE2EArgs +from 
TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs, ForwardTTSE2eAudio from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor + output_path = os.path.dirname(os.path.abspath(__file__)) @@ -15,30 +15,26 @@ output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", - # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/"), ) -audio_config = BaseAudioConfig( +audio_config = ForwardTTSE2eAudio( sample_rate=22050, - do_trim_silence=True, - trim_db=60.0, - signal_norm=False, + hop_length=256, + win_length=1024, + fft_size=1024, mel_fmin=0.0, mel_fmax=8000, - spec_gain=1.0, - log_func="np.log", - ref_level_db=20, - preemphasis=0.0, + pitch_fmax=640.0, num_mels=80, ) # vocoder_config = HifiganConfig() -model_args = ForwardTTSE2EArgs() +model_args = ForwardTTSE2eArgs() -config = FastPitchE2EConfig( +config = FastPitchE2eConfig( run_name="fast_pitch_e2e_ljspeech", - run_description="don't detach vocoder input.", + run_description="Train like in FS2 paper.", model_args=model_args, audio=audio_config, batch_size=32, @@ -63,14 +59,9 @@ config = FastPitchE2EConfig( output_path=output_path, datasets=[dataset_config], start_by_longest=False, - binary_align_loss_alpha=0.0 + binary_align_loss_alpha=0.0, ) -# INITIALIZE THE AUDIO PROCESSOR -# Audio processor is used for feature extraction and audio I/O. -# It mainly serves to the dataloader and the training loggers. -ap = AudioProcessor.init_from_config(config) - # INITIALIZE THE TOKENIZER # Tokenizer is used to convert text to sequences of token IDs. 
# If characters are not defined in the config, default characters are passed to the config @@ -89,7 +80,7 @@ train_samples, eval_samples = load_tts_samples( ) # init the model -model = ForwardTTSE2E(config, ap, tokenizer, speaker_manager=None) +model = ForwardTTSE2e(config=config, tokenizer=tokenizer, speaker_manager=None) # init the trainer and 🚀 trainer = Trainer(