diff --git a/train.py b/train.py index 0c3cff03..8aa0567d 100644 --- a/train.py +++ b/train.py @@ -569,7 +569,7 @@ def main(args): # pylint: disable=redefined-outer-name pos_weight=torch.tensor(10)) if c.stopnet else None if args.restore_path: - checkpoint = torch.load(args.restore_path) + checkpoint = torch.load(args.restore_path, map_location='cpu') try: # TODO: fix optimizer init, model.cuda() needs to be called before # optimizer restore