mirror of https://github.com/coqui-ai/TTS.git
refactor(utils): remove duplicate set_partial_state_dict
This commit is contained in:
parent
1b6d3ebd33
commit
1f27f994a1
|
@ -5,10 +5,10 @@ import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from trainer.generic_utils import set_partial_state_dict
|
||||||
from trainer.io import load_fsspec
|
from trainer.io import load_fsspec
|
||||||
|
|
||||||
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||||
from TTS.utils.generic_utils import set_init_dict
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ class BaseEncoder(nn.Module):
|
||||||
|
|
||||||
logger.info("Partial model initialization.")
|
logger.info("Partial model initialization.")
|
||||||
model_dict = self.state_dict()
|
model_dict = self.state_dict()
|
||||||
model_dict = set_init_dict(model_dict, state["model"], c)
|
model_dict = set_partial_state_dict(model_dict, state["model"], config)
|
||||||
self.load_state_dict(model_dict)
|
self.load_state_dict(model_dict)
|
||||||
del model_dict
|
del model_dict
|
||||||
|
|
||||||
|
|
|
@ -54,25 +54,6 @@ def get_import_path(obj: object) -> str:
|
||||||
return ".".join([type(obj).__module__, type(obj).__name__])
|
return ".".join([type(obj).__module__, type(obj).__name__])
|
||||||
|
|
||||||
|
|
||||||
def set_init_dict(model_dict, checkpoint_state, c):
|
|
||||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
|
||||||
for k, v in checkpoint_state.items():
|
|
||||||
if k not in model_dict:
|
|
||||||
logger.warning("Layer missing in the model finition %s", k)
|
|
||||||
# 1. filter out unnecessary keys
|
|
||||||
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
|
|
||||||
# 2. filter out different size layers
|
|
||||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
|
|
||||||
# 3. skip reinit layers
|
|
||||||
if c.has("reinit_layers") and c.reinit_layers is not None:
|
|
||||||
for reinit_layer_name in c.reinit_layers:
|
|
||||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
|
|
||||||
# 4. overwrite entries in the existing state dict
|
|
||||||
model_dict.update(pretrained_dict)
|
|
||||||
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
|
|
||||||
return model_dict
|
|
||||||
|
|
||||||
|
|
||||||
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
|
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
|
||||||
"""Format kwargs to hande auxilary inputs to models.
|
"""Format kwargs to hande auxilary inputs to models.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue