diff --git a/train.py b/train.py
index 3d4212bc..37b48cba 100644
--- a/train.py
+++ b/train.py
@@ -360,17 +360,25 @@ def main(args):
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
-        model.load_state_dict(checkpoint['model'])
-        # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
-        # 1. filter out unnecessary keys
-        pretrained_dict = {
-            k: v
-            for k, v in checkpoint['model'].items() if k in model_dict
-        }
-        # 2. overwrite entries in the existing state dict
-        model_dict.update(pretrained_dict)
-        # 3. load the new state dict
-        model.load_state_dict(model_dict)
+        try:
+            model.load_state_dict(checkpoint['model'])
+        except (RuntimeError, KeyError):
+            # Strict loading failed (missing/unexpected keys or a size
+            # mismatch), so fall back to copying only compatible tensors.
+            print(" > Partial model initialization.")
+            model_dict = model.state_dict()
+            # Partial initialization: checkpoint entries that are absent from
+            # the model, or whose size differs, are skipped.
+            # 1. keep only entries that exist in the model with the same size
+            pretrained_dict = {
+                k: v
+                for k, v in checkpoint['model'].items()
+                if k in model_dict and v.size() == model_dict[k].size()
+            }
+            # 2. overwrite entries in the existing state dict
+            model_dict.update(pretrained_dict)
+            # 3. load the new state dict
+            model.load_state_dict(model_dict)
         if use_cuda:
             model = model.cuda()
             criterion.cuda()