diff --git a/train.py b/train.py index 9b7df10a..f52d365d 100644 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ from utils.generic_utils import ( from utils.visual import plot_alignment, plot_spectrogram from models.tacotron import Tacotron from layers.losses import L1LossMasked +from datasets.TTSDataset import MyDataset from utils.audio import AudioProcessor from utils.synthesis import synthesis from utils.logger import Logger @@ -44,8 +45,8 @@ def setup_loader(is_val=False): ap=ap, batch_group_size=0 if is_val else 8 * c.batch_size, min_seq_len=0 if is_val else c.min_seq_len, - max_seq_len=float("inf") if is_val else c.max_seq_len - cached=False if c.dataset ~= "tts_cache" else True) + max_seq_len=float("inf") if is_val else c.max_seq_len, + cached=False if c.dataset != "tts_cache" else True) loader = DataLoader( dataset, batch_size=c.eval_batch_size if is_val else c.batch_size, @@ -444,7 +445,7 @@ if __name__ == '__main__': default=False, help='Do not verify commit integrity to run training.') parser.add_argument( - '--data_path', type=str, help='Defines the data path. It overwrites config.json.', default='') + '--data_path', type=str, default='', help='Defines the data path. It overwrites config.json.') args = parser.parse_args() # setup output paths and read configs @@ -467,8 +468,6 @@ if __name__ == '__main__': # Conditional imports preprocessor = importlib.import_module('datasets.preprocess') preprocessor = getattr(preprocessor, c.dataset.lower()) - MyDataset = importlib.import_module('datasets.' + c.data_loader) - MyDataset = getattr(MyDataset, "MyDataset") audio = importlib.import_module('utils.' + c.audio['audio_processor']) AudioProcessor = getattr(audio, 'AudioProcessor')