mirror of https://github.com/coqui-ai/TTS.git
set state dict using direct state_dict dict
This commit is contained in:
parent
3034797211
commit
a5fc2f578f
|
@ -107,15 +107,15 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
return seq_range_expand < seq_length_expand
|
||||
|
||||
|
||||
def set_init_dict(model_dict, checkpoint, c):
|
||||
def set_init_dict(model_dict, checkpoint_state, c):
|
||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||
for k, v in checkpoint['model'].items():
|
||||
for k, v in checkpoint_state.items():
|
||||
if k not in model_dict:
|
||||
print(" | > Layer missing in the model definition: {}".format(k))
|
||||
# 1. filter out unnecessary keys
|
||||
pretrained_dict = {
|
||||
k: v
|
||||
for k, v in checkpoint['model'].items() if k in model_dict
|
||||
for k, v in checkpoint_state.items() if k in model_dict
|
||||
}
|
||||
# 2. filter out different size layers
|
||||
pretrained_dict = {
|
||||
|
|
Loading…
Reference in New Issue