mirror of https://github.com/coqui-ai/TTS.git
print warning if a layer in checkpoint is not defined in model definition
This commit is contained in:
parent
be2f2b8d62
commit
8a47b46195
|
@@ -215,6 +215,9 @@ def sequence_mask(sequence_length, max_len=None):
|
||||||
|
|
||||||
def set_init_dict(model_dict, checkpoint, c):
|
def set_init_dict(model_dict, checkpoint, c):
|
||||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||||
|
for k, v in checkpoint['model'].items():
|
||||||
|
if k not in model_dict:
|
||||||
|
print(" | > Layer missing in the model definition: {}".format(k))
|
||||||
# 1. filter out unnecessary keys
|
# 1. filter out unnecessary keys
|
||||||
pretrained_dict = {
|
pretrained_dict = {
|
||||||
k: v
|
k: v
|
||||||
|
@@ -236,7 +239,7 @@ def set_init_dict(model_dict, checkpoint, c):
|
||||||
}
|
}
|
||||||
# 4. overwrite entries in the existing state dict
|
# 4. overwrite entries in the existing state dict
|
||||||
model_dict.update(pretrained_dict)
|
model_dict.update(pretrained_dict)
|
||||||
print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
|
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
|
||||||
return model_dict
|
return model_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue