diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 4a7aaef9..f7c04859 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -107,15 +107,15 @@ def sequence_mask(sequence_length, max_len=None): return seq_range_expand < seq_length_expand -def set_init_dict(model_dict, checkpoint, c): +def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. - for k, v in checkpoint['model'].items(): + for k, v in checkpoint_state.items(): if k not in model_dict: print(" | > Layer missing in the model definition: {}".format(k)) # 1. filter out unnecessary keys pretrained_dict = { k: v - for k, v in checkpoint['model'].items() if k in model_dict + for k, v in checkpoint_state.items() if k in model_dict } # 2. filter out different size layers pretrained_dict = {