From 22d62aee5b577e6b006f174bd40b6dd719028784 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 11 Dec 2018 17:53:08 +0100 Subject: [PATCH] partial model initialization --- train.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/train.py b/train.py index 54d6140f..8fe07ded 100644 --- a/train.py +++ b/train.py @@ -401,6 +401,16 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + # 1. filter out unnecessary keys + pretrained_dict = { + k: v + for k, v in checkpoint['model'].items() if k in model_dict + } + # 2. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + # 3. load the new state dict + model.load_state_dict(model_dict) if use_cuda: model = model.cuda() criterion.cuda()