From 0372c5ce5d066667102193133aa46133037c1d09 Mon Sep 17 00:00:00 2001 From: erogol Date: Sat, 30 May 2020 18:08:26 +0200 Subject: [PATCH] set state dict using direct state_dict dict --- utils/generic_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 5a811907..5b135061 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -106,15 +106,15 @@ def sequence_mask(sequence_length, max_len=None): return seq_range_expand < seq_length_expand -def set_init_dict(model_dict, checkpoint, c): +def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. - for k, v in checkpoint['model'].items(): + for k, v in checkpoint_state.items(): if k not in model_dict: print(" | > Layer missing in the model definition: {}".format(k)) # 1. filter out unnecessary keys pretrained_dict = { k: v - for k, v in checkpoint['model'].items() if k in model_dict + for k, v in checkpoint_state.items() if k in model_dict } # 2. filter out different size layers pretrained_dict = {