mirror of https://github.com/coqui-ai/TTS.git
refactor(utils): remove duplicate set_partial_state_dict
parent 1b6d3ebd33
commit 1f27f994a1
TTS/encoder/models/base_encoder.py

@@ -5,10 +5,10 @@ import torch
 import torchaudio
 from coqpit import Coqpit
 from torch import nn
+from trainer.generic_utils import set_partial_state_dict
 from trainer.io import load_fsspec
 
 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
-from TTS.utils.generic_utils import set_init_dict
 
 logger = logging.getLogger(__name__)
 
@@ -130,7 +130,7 @@ class BaseEncoder(nn.Module):
 
             logger.info("Partial model initialization.")
             model_dict = self.state_dict()
-            model_dict = set_init_dict(model_dict, state["model"], c)
+            model_dict = set_partial_state_dict(model_dict, state["model"], config)
             self.load_state_dict(model_dict)
             del model_dict
 
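The call site now delegates the partial restore to the trainer package's `set_partial_state_dict`, passing the full `config` object where the old helper took the positional `c`. Below is a minimal sketch of the resulting flow, assuming a checkpoint dict with a "model" key as shown in the diff; the helper name `load_partial` is hypothetical (in the repo this logic lives inside `BaseEncoder.load_checkpoint`):

```python
import torch
from trainer.generic_utils import set_partial_state_dict

def load_partial(model: torch.nn.Module, checkpoint_path: str, config):
    """Hypothetical helper mirroring the updated BaseEncoder code path."""
    state = torch.load(checkpoint_path, map_location="cpu")
    model_dict = model.state_dict()
    # Merge in only the checkpoint tensors that match the current model.
    model_dict = set_partial_state_dict(model_dict, state["model"], config)
    model.load_state_dict(model_dict)
```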
TTS/utils/generic_utils.py

@@ -54,25 +54,6 @@ def get_import_path(obj: object) -> str:
     return ".".join([type(obj).__module__, type(obj).__name__])
 
 
-def set_init_dict(model_dict, checkpoint_state, c):
-    # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
-    for k, v in checkpoint_state.items():
-        if k not in model_dict:
-            logger.warning("Layer missing in the model finition %s", k)
-    # 1. filter out unnecessary keys
-    pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
-    # 2. filter out different size layers
-    pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
-    # 3. skip reinit layers
-    if c.has("reinit_layers") and c.reinit_layers is not None:
-        for reinit_layer_name in c.reinit_layers:
-            pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
-    # 4. overwrite entries in the existing state dict
-    model_dict.update(pretrained_dict)
-    logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
-    return model_dict
-
-
 def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
     """Format kwargs to hande auxilary inputs to models.
 
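The deleted helper filtered the checkpoint state in four steps before merging it into the model's state dict; per the commit message, `trainer.generic_utils.set_partial_state_dict` already duplicates this behavior, so the local copy goes away. A self-contained sketch of those four steps, where `DummyConfig` is a hypothetical stand-in for the Coqpit config the real function received:

```python
import torch

class DummyConfig:
    """Hypothetical stand-in for the Coqpit config used by set_init_dict."""
    reinit_layers = ["proj"]  # layer-name substrings to leave randomly initialized
    def has(self, key):  # mimics Coqpit.has()
        return hasattr(self, key)

model_dict = {
    "enc.weight": torch.zeros(4, 4),
    "proj.weight": torch.zeros(2, 4),
}
checkpoint_state = {
    "enc.weight": torch.ones(4, 4),   # kept: present in the model, same size
    "proj.weight": torch.ones(2, 4),  # dropped: matches a reinit_layers entry
    "old.bias": torch.ones(3),        # dropped: not in the model at all
}

c = DummyConfig()
# 1. keep only keys the model knows about
pretrained = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. drop tensors whose sizes no longer match the model's
pretrained = {k: v for k, v in pretrained.items() if v.numel() == model_dict[k].numel()}
# 3. honor reinit_layers so listed layers stay randomly initialized
if c.has("reinit_layers") and c.reinit_layers is not None:
    for name in c.reinit_layers:
        pretrained = {k: v for k, v in pretrained.items() if name not in k}
# 4. overwrite matching entries in the model's state dict
model_dict.update(pretrained)
print(f"{len(pretrained)} / {len(model_dict)} layers are restored.")  # -> 1 / 2
```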