refactor(utils): remove duplicate set_partial_state_dict

This commit is contained in:
Enno Hermann 2024-11-20 18:40:28 +01:00
parent 1b6d3ebd33
commit 1f27f994a1
2 changed files with 2 additions and 21 deletions

View File

@ -5,10 +5,10 @@ import torch
import torchaudio
from coqpit import Coqpit
from torch import nn
from trainer.generic_utils import set_partial_state_dict
from trainer.io import load_fsspec
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
logger = logging.getLogger(__name__)
@ -130,7 +130,7 @@ class BaseEncoder(nn.Module):
logger.info("Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"], c)
model_dict = set_partial_state_dict(model_dict, state["model"], config)
self.load_state_dict(model_dict)
del model_dict

View File

@ -54,25 +54,6 @@ def get_import_path(obj: object) -> str:
return ".".join([type(obj).__module__, type(obj).__name__])
def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
logger.warning("Layer missing in the model finition %s", k)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
# 3. skip reinit layers
if c.has("reinit_layers") and c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
"""Format kwargs to hande auxilary inputs to models.