set state dict using direct state_dict dict

This commit is contained in:
erogol 2020-05-30 18:08:26 +02:00
parent 3034797211
commit a5fc2f578f
1 changed files with 3 additions and 3 deletions

View File

@ -107,15 +107,15 @@ def sequence_mask(sequence_length, max_len=None):
return seq_range_expand < seq_length_expand
def set_init_dict(model_dict, checkpoint, c):
def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint['model'].items():
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
# 1. filter out unnecessary keys
pretrained_dict = {
k: v
for k, v in checkpoint['model'].items() if k in model_dict
for k, v in checkpoint_state.items() if k in model_dict
}
# 2. filter out different size layers
pretrained_dict = {