diff --git a/train.py b/train.py
index 076b7042..cdbfa4bc 100644
--- a/train.py
+++ b/train.py
@@ -367,28 +367,37 @@ def main(args):
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
 
+    partial_init_flag = False
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         try:
             model.load_state_dict(checkpoint['model'])
         except:
             print(" > Partial model initialization.")
+            partial_init_flag = True
             model_dict = model.state_dict()
             # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
             # 1. filter out unnecessary keys
             pretrained_dict = {
                 k: v
-                for k, v in checkpoint['model'].items() if k in model_dict
+                for k, v in checkpoint['model'].items() if k in model_dict
             }
-            # 2. overwrite entries in the existing state dict
+            # 2. filter out different size layers
+            pretrained_dict = {
+                k: v
+                for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()
+            }
+            # 3. overwrite entries in the existing state dict
             model_dict.update(pretrained_dict)
-            # 3. load the new state dict
+            # 4. load the new state dict
             model.load_state_dict(model_dict)
+            print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
         if use_cuda:
             model = model.cuda()
             criterion.cuda()
             criterion_st.cuda()
-        optimizer.load_state_dict(checkpoint['optimizer'])
+        if not partial_init_flag:
+            optimizer.load_state_dict(checkpoint['optimizer'])
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(
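
For reference, here is a minimal standalone sketch of the partial-initialization logic this diff introduces. The helper name `partial_load` and its signature are assumptions for illustration; it only assumes a checkpoint dict that stores the weights under a 'model' key, as train.py does.

```python
import torch
import torch.nn as nn


def partial_load(model: nn.Module, checkpoint: dict) -> int:
    """Copy only the checkpoint tensors whose name and size match the model.

    Hypothetical helper mirroring the logic added in this diff; returns the
    number of layers that were actually initialized from the checkpoint.
    """
    model_dict = model.state_dict()
    # 1. keep only keys that also exist in the current model
    pretrained_dict = {
        k: v for k, v in checkpoint['model'].items() if k in model_dict
    }
    # 2. drop tensors whose size no longer matches the current layer
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if v.numel() == model_dict[k].numel()
    }
    # 3. overwrite the matching entries and 4. load the merged state dict
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print(" | > {} / {} layers are initialized".format(
        len(pretrained_dict), len(model_dict)))
    return len(pretrained_dict)


# Usage sketch (paths and model are placeholders):
# checkpoint = torch.load(restore_path)
# n_loaded = partial_load(model, checkpoint)
```

A caller would then skip restoring the optimizer state whenever not every layer was loaded, which is what the `partial_init_flag` check in the diff does: a partially initialized model no longer matches the saved optimizer state.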