diff --git a/recipes/librispeech/stt/deep_speech/train_deep_speech.py b/recipes/librispeech/stt/deep_speech/train_deep_speech.py
index b691fc57..b10bf0b6 100644
--- a/recipes/librispeech/stt/deep_speech/train_deep_speech.py
+++ b/recipes/librispeech/stt/deep_speech/train_deep_speech.py
@@ -13,21 +13,27 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 
 if not os.path.exists("/home/ubuntu/librispeech/LibriSpeech/train-clean-100"):
     download_librispeech("/home/ubuntu/librispeech/", "train-clean-100")
+if not os.path.exists("/home/ubuntu/librispeech/LibriSpeech/train-clean-360"):
+    download_librispeech("/home/ubuntu/librispeech/", "train-clean-360")
+if not os.path.exists("/home/ubuntu/librispeech/LibriSpeech/train-other-500"):
+    download_librispeech("/home/ubuntu/librispeech/", "train-other-500")
 if not os.path.exists("/home/ubuntu/librispeech/LibriSpeech/dev-clean"):
     download_librispeech("/home/ubuntu/librispeech/", "dev-clean")
 
-# train_dataset_config = BaseDatasetConfig(
-#     name="librispeech", meta_file_train=None, path="/home/ubuntu/librispeech/LibriSpeech/train-clean-100"
-# )
+train_dataset_config1 = BaseDatasetConfig(
+    name="librispeech", meta_file_train=None, path="/home/ubuntu/librispeech/LibriSpeech/train-clean-100"
+)
 
-# eval_dataset_config = BaseDatasetConfig(
-#     name="librispeech", meta_file_train=None, path="/home/ubuntu/librispeech/LibriSpeech/dev-clean"
-# )
+train_dataset_config2 = BaseDatasetConfig(
+    name="librispeech", meta_file_train=None, path="/home/ubuntu/librispeech/LibriSpeech/train-clean-360"
+)
 
-train_dataset_config = BaseDatasetConfig(
-    name="ljspeech",
-    meta_file_train="metadata.csv",
-    path="/home/ubuntu/ljspeech/LJSpeech-1.1/",
+train_dataset_config3 = BaseDatasetConfig(
+    name="librispeech", meta_file_train=None, path="/home/ubuntu/librispeech/LibriSpeech/train-other-500"
+)
+
+eval_dataset_config = BaseDatasetConfig(
+    name="librispeech", meta_file_train=None, path="/home/ubuntu/librispeech/LibriSpeech/dev-clean"
 )
 
 
@@ -59,16 +65,16 @@ config = DeepSpeechConfig(
     mixed_precision=True,
     max_seq_len=500000,
     output_path=output_path,
-    train_datasets=[train_dataset_config],
-    # eval_datasets=[eval_dataset_config],
+    train_datasets=[train_dataset_config1, train_dataset_config2, train_dataset_config3],
+    eval_datasets=[eval_dataset_config],
 )
 
 # init audio processor
 ap = AudioProcessor(**config.audio.to_dict())
 
 # load training samples
-train_samples, eval_samples = load_stt_samples(train_dataset_config, eval_split=True)
-# eval_samples, _ = load_stt_samples(eval_dataset_config, eval_split=False)
+train_samples, _ = load_stt_samples(config.train_datasets, eval_split=False)
+eval_samples, _ = load_stt_samples(config.eval_datasets, eval_split=False)
 transcripts = [s["text"] for s in train_samples]
 
 
@@ -81,13 +87,11 @@ config.vocabulary = tokenizer.vocab_dict
 model = DeepSpeech(config)
 
 # init training and kick it 🚀
-# args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
 trainer = Trainer(
     TrainingArgs(),
     config,
     output_path,
     model=model,
-    tokenizer=tokenizer,
     train_samples=train_samples,
     eval_samples=eval_samples,
     cudnn_benchmark=False,