diff --git a/train.py b/train.py index aa8e92ef..fd0c03f7 100644 --- a/train.py +++ b/train.py @@ -361,7 +361,7 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint['optimizer'].cuda()) print(" > Model restored from step %d" % checkpoint['step']) start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss']