mirror of https://github.com/coqui-ai/TTS.git
print warning if a layer in ehckpoint is not defined in model definition
This commit is contained in:
parent
be2f2b8d62
commit
8a47b46195
|
@ -215,6 +215,9 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
|
||||
def set_init_dict(model_dict, checkpoint, c):
|
||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||
for k, v in checkpoint['model'].items():
|
||||
if k not in model_dict:
|
||||
print(" | > Layer missing in the model definition: {}".format(k))
|
||||
# 1. filter out unnecessary keys
|
||||
pretrained_dict = {
|
||||
k: v
|
||||
|
@ -236,7 +239,7 @@ def set_init_dict(model_dict, checkpoint, c):
|
|||
}
|
||||
# 4. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrained_dict)
|
||||
print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
|
||||
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
|
||||
return model_dict
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue