From a1fe8673716061e972dbda2206f64114e9c4c18f Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Thu, 13 Dec 2018 18:19:02 +0100 Subject: [PATCH] Bug fix for partial model initialization: if the model cannot be initialized from the checkpoint as-is, fall back to a partial initialization using only the layers whose sizes match --- train.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/train.py b/train.py index b86ab04e..193306c3 100644 --- a/train.py +++ b/train.py @@ -360,17 +360,21 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) -    model.load_state_dict(checkpoint['model']) -    # Partial initialization: if there is a mismatch with new and old layer, it is skipped. -    # 1. filter out unnecessary keys -    pretrained_dict = { -        k: v -        for k, v in checkpoint['model'].items() if k in model_dict -    } -    # 2. overwrite entries in the existing state dict -    model_dict.update(pretrained_dict) -    # 3. load the new state dict -    model.load_state_dict(model_dict) +    try: +        model.load_state_dict(checkpoint['model']) +    except: +        print(" > Partial model initialization.") +        model_dict = model.state_dict() +        # Partial initialization: if there is a mismatch with new and old layer, it is skipped. +        # 1. filter out unnecessary keys +        pretrained_dict = { +            k: v +            for k, v in checkpoint['model'].items() if k in model_dict +        } +        # 2. overwrite entries in the existing state dict +        model_dict.update(pretrained_dict) +        # 3. load the new state dict +        model.load_state_dict(model_dict) if use_cuda: model = model.cuda() criterion.cuda()