From e6112f7b2d1cebf251f96664b3a621d4d08f67cd Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Wed, 16 May 2018 19:20:40 -0700 Subject: [PATCH] restore fix --- train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 2516ad55..f570b7fe 100644 --- a/train.py +++ b/train.py @@ -371,12 +371,18 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) + optimizer = optim.Adam(model.parameters(), lr=c.lr) optimizer.load_state_dict(checkpoint['optimizer']) - print("\n > Model restored from step %d\n" % checkpoint['step']) + for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = v.cuda() + print(" > Model restored from step %d" % checkpoint['step']) start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] start_epoch = 0 args.restore_step = checkpoint['step'] + optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr) else: args.restore_step = 0 print("\n > Starting a new training")