set state dict using direct state_dict dict

This commit is contained in:
erogol 2020-05-30 18:08:26 +02:00
parent 3034797211
commit a5fc2f578f
1 changed files with 3 additions and 3 deletions

View File

@ -107,15 +107,15 @@ def sequence_mask(sequence_length, max_len=None):
return seq_range_expand < seq_length_expand return seq_range_expand < seq_length_expand
def set_init_dict(model_dict, checkpoint, c): def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped. # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint['model'].items(): for k, v in checkpoint_state.items():
if k not in model_dict: if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k)) print(" | > Layer missing in the model definition: {}".format(k))
# 1. filter out unnecessary keys # 1. filter out unnecessary keys
pretrained_dict = { pretrained_dict = {
k: v k: v
for k, v in checkpoint['model'].items() if k in model_dict for k, v in checkpoint_state.items() if k in model_dict
} }
# 2. filter out different size layers # 2. filter out different size layers
pretrained_dict = { pretrained_dict = {