From dbfc489775844d7ad8270b8062cdc25cdfe77286 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 17 Apr 2018 10:05:50 -0700 Subject: [PATCH] Remove redundant dataset import --- train.py | 59 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/train.py b/train.py index 3786c0ee..7e91a77a 100644 --- a/train.py +++ b/train.py @@ -23,7 +23,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder, count_parameters, check_update, get_commit_hash) from utils.model import get_param_size from utils.visual import plot_alignment, plot_spectrogram -from datasets.LJSpeech import LJSpeechDataset from models.tacotron import Tacotron from layers.losses import L1LossMasked @@ -296,41 +295,41 @@ def main(args): Dataset = getattr(mod, c.dataset+"Dataset") # Setup the dataset - train_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_train), - os.path.join(c.data_path, 'wavs'), - c.r, - c.sample_rate, - c.text_cleaner, - c.num_mels, - c.min_level_db, - c.frame_shift_ms, - c.frame_length_ms, - c.preemphasis, - c.ref_level_db, - c.num_freq, - c.power, - min_seq_len=c.min_seq_len - ) + train_dataset = Dataset(os.path.join(c.data_path, c.meta_file_train), + os.path.join(c.data_path, 'wavs'), + c.r, + c.sample_rate, + c.text_cleaner, + c.num_mels, + c.min_level_db, + c.frame_shift_ms, + c.frame_length_ms, + c.preemphasis, + c.ref_level_db, + c.num_freq, + c.power, + min_seq_len=c.min_seq_len + ) train_loader = DataLoader(train_dataset, batch_size=c.batch_size, shuffle=False, collate_fn=train_dataset.collate_fn, drop_last=False, num_workers=c.num_loader_workers, pin_memory=True) - val_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_val), - os.path.join(c.data_path, 'wavs'), - c.r, - c.sample_rate, - c.text_cleaner, - c.num_mels, - c.min_level_db, - c.frame_shift_ms, - c.frame_length_ms, - c.preemphasis, - c.ref_level_db, - c.num_freq, - c.power - ) + val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val), + os.path.join(c.data_path, 'wavs'), + c.r, + c.sample_rate, + c.text_cleaner, + c.num_mels, + c.min_level_db, + c.frame_shift_ms, + c.frame_length_ms, + c.preemphasis, + c.ref_level_db, + c.num_freq, + c.power + ) val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size, shuffle=False, collate_fn=val_dataset.collate_fn,