mirror of https://github.com/coqui-ai/TTS.git
Set the model state dict directly from a state_dict dictionary
This commit is contained in:
parent
3034797211
commit
a5fc2f578f
|
@ -107,15 +107,15 @@ def sequence_mask(sequence_length, max_len=None):
|
||||||
return seq_range_expand < seq_length_expand
|
return seq_range_expand < seq_length_expand
|
||||||
|
|
||||||
|
|
||||||
def set_init_dict(model_dict, checkpoint, c):
|
def set_init_dict(model_dict, checkpoint_state, c):
|
||||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||||
for k, v in checkpoint['model'].items():
|
for k, v in checkpoint_state.items():
|
||||||
if k not in model_dict:
|
if k not in model_dict:
|
||||||
print(" | > Layer missing in the model definition: {}".format(k))
|
print(" | > Layer missing in the model definition: {}".format(k))
|
||||||
# 1. filter out unnecessary keys
|
# 1. filter out unnecessary keys
|
||||||
pretrained_dict = {
|
pretrained_dict = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in checkpoint['model'].items() if k in model_dict
|
for k, v in checkpoint_state.items() if k in model_dict
|
||||||
}
|
}
|
||||||
# 2. filter out different size layers
|
# 2. filter out different size layers
|
||||||
pretrained_dict = {
|
pretrained_dict = {
|
||||||
|
|
Loading…
Reference in New Issue