def to_camel(text):
    """Convert a snake_case string to camelCase (the first letter is left unchanged)."""
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_speaker_encoder_model(config: "Coqpit"):
    """Instantiate a speaker-encoder model from *config*.

    Args:
        config (Coqpit): configuration object with a ``model_params`` mapping
            (keys ``model_name``, ``input_dim``, ``proj_dim``, and, per model,
            ``lstm_dim``/``num_lstm_layers`` or ``log_input``/``use_torch_spec``)
            and an ``audio`` section (used only by the resnet encoder).

    Returns:
        LSTMSpeakerEncoder or ResNetSpeakerEncoder: the configured model.

    Raises:
        ValueError: if ``model_params["model_name"]`` names an unknown model.
    """
    # Hoist the lookup; comparison is case-insensitive.
    model_name = config.model_params["model_name"].lower()
    if model_name == "lstm":
        return LSTMSpeakerEncoder(
            config.model_params["input_dim"],
            config.model_params["proj_dim"],
            config.model_params["lstm_dim"],
            config.model_params["num_lstm_layers"],
        )
    if model_name == "resnet":
        return ResNetSpeakerEncoder(
            input_dim=config.model_params["input_dim"],
            proj_dim=config.model_params["proj_dim"],
            # Optional flags default to False when absent from the config.
            log_input=config.model_params.get("log_input", False),
            use_torch_spec=config.model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    # Previously an unknown model name fell through both branches and
    # ``return model`` raised UnboundLocalError; fail with a clear message.
    raise ValueError(f" [!] Model {model_name} not supported.")