From d4f1ccd3ede16e48654838d73936cc0d23b7655c Mon Sep 17 00:00:00 2001 From: Eren G Date: Tue, 17 Jul 2018 15:59:31 +0200 Subject: [PATCH] Move things into main and set thread --- train.py | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/train.py b/train.py index 2c380f57..d0a04122 100644 --- a/train.py +++ b/train.py @@ -27,30 +27,11 @@ from datasets.LJSpeech import LJSpeechDataset from models.tacotron import Tacotron from layers.losses import L1LossMasked + torch.manual_seed(1) +torch.set_num_threads(4) use_cuda = torch.cuda.is_available() -parser = argparse.ArgumentParser() -parser.add_argument('--restore_path', type=str, - help='Folder path to checkpoints', default=0) -parser.add_argument('--config_path', type=str, - help='path to config file for training',) -parser.add_argument('--debug', type=bool, default=False, - help='do not ask for git has before run.') -args = parser.parse_args() - -# setup output paths and read configs -c = load_config(args.config_path) -_ = os.path.dirname(os.path.realpath(__file__)) -OUT_PATH = os.path.join(_, c.output_path) -OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug) -CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints') -shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) - -# setup tensorboard -LOG_DIR = OUT_PATH -tb = SummaryWriter(LOG_DIR) - def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch): model = model.train() @@ -440,6 +421,27 @@ def main(args): if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--restore_path', type=str, + help='Folder path to checkpoints', default=0) + parser.add_argument('--config_path', type=str, + help='path to config file for training',) + parser.add_argument('--debug', type=bool, default=False, + help='do not ask for git has before run.') + args = parser.parse_args() + + # setup output paths and read configs + c = load_config(args.config_path) + _ = os.path.dirname(os.path.realpath(__file__)) + OUT_PATH = os.path.join(_, c.output_path) + OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug) + CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints') + shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) + + # setup tensorboard + LOG_DIR = OUT_PATH + tb = SummaryWriter(LOG_DIR) + try: main(args) except KeyboardInterrupt: