diff --git a/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py b/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py
new file mode 100644
index 00000000..31b3da9c
--- /dev/null
+++ b/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py
@@ -0,0 +1,98 @@
+import os
+
+from trainer import Trainer, TrainerArgs
+
+from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
+from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2EConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.models.forward_tts_e2e import ForwardTTSE2E, ForwardTTSE2EArgs
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
+from TTS.utils.audio import AudioProcessor
+
+output_path = os.path.dirname(os.path.abspath(__file__))
+
+# init configs
+dataset_config = BaseDatasetConfig(
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
+    path=os.path.join(output_path, "../LJSpeech-1.1/"),
+)
+
+audio_config = BaseAudioConfig(
+    sample_rate=22050,
+    do_trim_silence=True,
+    trim_db=60.0,
+    signal_norm=False,
+    mel_fmin=0.0,
+    mel_fmax=8000,
+    spec_gain=1.0,
+    log_func="np.log",
+    ref_level_db=20,
+    preemphasis=0.0,
+    num_mels=80,
+)
+
+# vocoder_config = HifiganConfig()
+model_args = ForwardTTSE2EArgs()
+
+config = FastPitchE2EConfig(
+    run_name="fast_pitch_e2e_ljspeech",
+    run_description="don't detach vocoder input.",
+    model_args=model_args,
+    audio=audio_config,
+    batch_size=32,
+    eval_batch_size=16,
+    num_loader_workers=8,
+    num_eval_loader_workers=4,
+    compute_input_seq_cache=True,
+    compute_f0=True,
+    f0_cache_path=os.path.join(output_path, "f0_cache"),
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1000,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
+    precompute_num_workers=4,
+    print_step=50,
+    print_eval=False,
+    mixed_precision=False,
+    sort_by_audio_len=True,  # sort samples by audio length to reduce padding
+    output_path=output_path,
+    datasets=[dataset_config],
+    start_by_longest=False,
+    binary_align_loss_alpha=0.0,  # disable the binary alignment loss for this run
+)
+
+# INITIALIZE THE AUDIO PROCESSOR
+# Audio processor is used for feature extraction and audio I/O.
+# It mainly serves the dataloader and the training loggers.
+ap = AudioProcessor.init_from_config(config)
+
+# INITIALIZE THE TOKENIZER
+# Tokenizer is used to convert text to sequences of token IDs.
+# If characters are not defined in the config, default characters are passed to the config.
+tokenizer, config = TTSTokenizer.init_from_config(config)
+
+# LOAD DATA SAMPLES
+# Each sample is a list of `[text, audio_file_path, speaker_name]`.
+# You can define your custom sample loader returning the list of samples.
+# Or define a custom formatter and pass it to `load_tts_samples`.
+# Check `TTS.tts.datasets.load_tts_samples` for more details.
+train_samples, eval_samples = load_tts_samples(
+    dataset_config,
+    eval_split=True,
+    eval_split_max_size=config.eval_split_max_size,
+    eval_split_size=config.eval_split_size,
+)
+
+# init the model
+model = ForwardTTSE2E(config, ap, tokenizer, speaker_manager=None)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
+trainer.fit()
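
Note on running the recipe: as with the other LJSpeech recipes in this repo, the script is launched directly, e.g. `CUDA_VISIBLE_DEVICES=0 python recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py`. `Trainer` parses command-line arguments by default, so an interrupted run can usually be resumed by pointing `--continue_path` at the previous run folder created under `output_path`; verify the exact flags against the installed `trainer` version.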
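
After training, a quick sanity check is to synthesize a sentence with the generic `Synthesizer` wrapper, which can drive a TTS checkpoint directly. The sketch below is a minimal example and not part of the patch: the run-folder and file names are hypothetical placeholders, and whether `Synthesizer` handles the experimental `ForwardTTSE2E` model end to end (it outputs waveforms directly, so no separate vocoder should be needed) is an assumption to confirm.

# Minimal post-training smoke test -- a sketch under the assumptions above.
from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="fast_pitch_e2e_ljspeech-<run_date>/best_model.pth",  # hypothetical path
    tts_config_path="fast_pitch_e2e_ljspeech-<run_date>/config.json",  # hypothetical path
    use_cuda=True,
)
wav = synthesizer.tts("This is a test sentence.")
synthesizer.save_wav(wav, "test_output.wav")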