From 7cdeef1b5cefd2c40a9391141cb9ef365a114234 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 18 Dec 2018 12:58:09 +0100 Subject: [PATCH] bug fixes --- train.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 7dd5b78d..f52d365d 100644 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ from utils.generic_utils import ( from utils.visual import plot_alignment, plot_spectrogram from models.tacotron import Tacotron from layers.losses import L1LossMasked +from datasets.TTSDataset import MyDataset from utils.audio import AudioProcessor from utils.synthesis import synthesis from utils.logger import Logger @@ -44,8 +45,8 @@ def setup_loader(is_val=False): ap=ap, batch_group_size=0 if is_val else 8 * c.batch_size, min_seq_len=0 if is_val else c.min_seq_len, - max_seq_len=float("inf") if is_val else c.max_seq_len - cached=False if c.dataset ~= "tts_cache" else True) + max_seq_len=float("inf") if is_val else c.max_seq_len, + cached=False if c.dataset != "tts_cache" else True) loader = DataLoader( dataset, batch_size=c.eval_batch_size if is_val else c.batch_size, @@ -444,7 +445,7 @@ if __name__ == '__main__': default=False, help='Do not verify commit integrity to run training.') parser.add_argument( - '--data_path', type=str, default='', default='Defines the data path. It overwrites config.json.') + '--data_path', type=str, default='', help='Defines the data path. It overwrites config.json.') args = parser.parse_args() # setup output paths and read configs @@ -467,8 +468,6 @@ if __name__ == '__main__': # Conditional imports preprocessor = importlib.import_module('datasets.preprocess') preprocessor = getattr(preprocessor, c.dataset.lower()) - MyDataset = importlib.import_module('datasets.' + c.data_loader) - MyDataset = getattr(MyDataset, "MyDataset") audio = importlib.import_module('utils.' + c.audio['audio_processor']) AudioProcessor = getattr(audio, 'AudioProcessor')