diff --git a/train.py b/train.py
index 1e186ba0..fd257d31 100644
--- a/train.py
+++ b/train.py
@@ -367,11 +367,11 @@ def main(args):
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
 
-    partial_init_flag = False
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         try:
             model.load_state_dict(checkpoint['model'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
         except:
             print(" > Partial model initialization.")
             partial_init_flag = True
@@ -385,7 +385,7 @@ def main(args):
             # 2. filter out different size layers
             pretrained_dict = {
                 k: v
-                for k, v in checkpoint['model'].items() if v.numel() == model_dict[k].numel()
+                for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()
             }
             # 3. overwrite entries in the existing state dict
             model_dict.update(pretrained_dict)
@@ -396,8 +396,6 @@ def main(args):
         model = model.cuda()
        criterion.cuda()
        criterion_st.cuda()
-        if not partial_init_flag:
-            optimizer.load_state_dict(checkpoint['optimizer'])
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(