From 1b59d8110ca29d311be9c2cd96488b2019b25ac9 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Sat, 28 Apr 2018 13:11:02 -0700 Subject: [PATCH] fix optimizer for restored model --- train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 71dd4dca..aa8e92ef 100644 --- a/train.py +++ b/train.py @@ -362,17 +362,18 @@ def main(args): checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) - print("\n > Model restored from step %d\n" % checkpoint['step']) + print(" > Model restored from step %d" % checkpoint['step']) start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] start_epoch = 0 args.restore_step = checkpoint['step'] else: args.restore_step = 0 - print("\n > Starting a new training") + print(" > Starting a new training") if use_cuda: - model = nn.DataParallel(model.cuda()) + print(" > Using CUDA.") + model = nn.DataParallel(model).cuda() num_params = count_parameters(model) print(" | > Model has {} parameters".format(num_params))