print warning if a layer in ehckpoint is not defined in model definition

This commit is contained in:
Eren Golge 2019-04-08 19:32:07 +02:00
parent be2f2b8d62
commit 8a47b46195
1 changed files with 4 additions and 1 deletions

View File

@ -215,6 +215,9 @@ def sequence_mask(sequence_length, max_len=None):
def set_init_dict(model_dict, checkpoint, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint['model'].items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
# 1. filter out unnecessary keys
pretrained_dict = {
k: v
@ -236,7 +239,7 @@ def set_init_dict(model_dict, checkpoint, c):
}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
return model_dict