From a5fc2f578f0d570d79a67fe8a0251a2160736a67 Mon Sep 17 00:00:00 2001 From: erogol Date: Sat, 30 May 2020 18:08:26 +0200 Subject: [PATCH] set state dict using direct state_dict dict --- utils/generic_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 4a7aaef9..f7c04859 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -107,15 +107,15 @@ def sequence_mask(sequence_length, max_len=None): return seq_range_expand < seq_length_expand -def set_init_dict(model_dict, checkpoint, c): +def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. - for k, v in checkpoint['model'].items(): + for k, v in checkpoint_state.items(): if k not in model_dict: print(" | > Layer missing in the model definition: {}".format(k)) # 1. filter out unnecessary keys pretrained_dict = { k: v - for k, v in checkpoint['model'].items() if k in model_dict + for k, v in checkpoint_state.items() if k in model_dict } # 2. filter out different size layers pretrained_dict = {