diff --git a/train.py b/train.py index 6a7f2ebd..97eec378 100644 --- a/train.py +++ b/train.py @@ -392,6 +392,8 @@ def main(args): # TODO: fix optimizer init, model.cuda() needs to be called before # optimizer restore # optimizer.load_state_dict(checkpoint['optimizer']) + if len(c.reinit_layers) > 0: + raise RuntimeError model.load_state_dict(checkpoint['model']) except: print(" > Partial model initialization.")