From 22822cd41c9e3f962863825d22b3250024263ac5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 10 Sep 2021 08:31:10 +0000
Subject: [PATCH] Add LJSpeech SpeedySpeech recipe

---
 .../speedy_speech/train_speedy_speech.py | 68 +++++++++++++++++++
 1 file changed, 68 insertions(+)
 create mode 100644 recipes/ljspeech/speedy_speech/train_speedy_speech.py

diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py
new file mode 100644
index 00000000..2882468f
--- /dev/null
+++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py
@@ -0,0 +1,68 @@
+import os
+
+from TTS.config import BaseAudioConfig, BaseDatasetConfig
+from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.tts.configs import SpeedySpeechConfig
+from TTS.utils.manage import ModelManager
+
+output_path = os.path.dirname(os.path.abspath(__file__))
+
+# init configs
+dataset_config = BaseDatasetConfig(
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
+    path=os.path.join(output_path, "../LJSpeech-1.1/"),
+)
+
+audio_config = BaseAudioConfig(
+    sample_rate=22050,
+    do_trim_silence=True,
+    trim_db=60.0,
+    signal_norm=False,
+    mel_fmin=0.0,
+    mel_fmax=8000,
+    spec_gain=1.0,
+    log_func="np.log",
+    ref_level_db=20,
+    preemphasis=0.0,
+)
+
+config = SpeedySpeechConfig(
+    run_name="speedy_speech_ljspeech",
+    audio=audio_config,
+    batch_size=32,
+    eval_batch_size=16,
+    num_loader_workers=4,
+    num_eval_loader_workers=4,
+    compute_input_seq_cache=True,
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1000,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    use_espeak_phonemes=False,
+    phoneme_language="en-us",
+    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
+    print_step=50,
+    print_eval=False,
+    mixed_precision=False,
+    sort_by_audio_len=True,
+    max_seq_len=500000,
+    output_path=output_path,
+    datasets=[dataset_config],
+)
+
+# compute alignments
+if not config.model_args.use_aligner:
+    manager = ModelManager()
+    model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
+    # TODO: make compute_attention python callable
+    os.system(
+        f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
+    )
+
+# train the model
+args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+trainer.fit()
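
Note (not part of the patch): the recipe expects an extracted LJSpeech-1.1 folder at recipes/ljspeech/LJSpeech-1.1/, i.e. one level above the recipe directory, as implied by the path=os.path.join(output_path, "../LJSpeech-1.1/") setting. Below is a rough sketch of fetching the dataset with the Python standard library; the download URL is the commonly used keithito.com mirror and is an assumption here, not something the patch pins down.

    # Sketch only, not part of the patch: fetch LJSpeech into the layout the recipe expects.
    # The URL is the commonly used keithito.com mirror (an assumption); run from the repo root.
    import os
    import tarfile
    import urllib.request

    LJSPEECH_URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"  # assumed mirror
    target_dir = "recipes/ljspeech"  # so recipes/ljspeech/LJSpeech-1.1/ sits next to the recipe folders

    os.makedirs(target_dir, exist_ok=True)
    archive_path = os.path.join(target_dir, "LJSpeech-1.1.tar.bz2")

    if not os.path.isdir(os.path.join(target_dir, "LJSpeech-1.1")):
        # download the archive (roughly 2.6 GB), then unpack metadata.csv and wavs/
        urllib.request.urlretrieve(LJSPEECH_URL, archive_path)
        with tarfile.open(archive_path, "r:bz2") as tar:
            tar.extractall(target_dir)

With the dataset in place, the recipe is launched from the repository root with:
python recipes/ljspeech/speedy_speech/train_speedy_speech.py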