From 3818bd0c2308fe31e5a3a811e33fb14931267341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 13 Dec 2021 16:29:19 +0000 Subject: [PATCH] Fixup --- TTS/speaker_encoder/utils/generic_utils.py | 24 +++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index c926e215..dab79f3c 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -170,21 +170,21 @@ def to_camel(text): return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) -def setup_model(c): - if c.model_params["model_name"].lower() == "lstm": +def setup_speaker_encoder_model(config: "Coqpit"): + if config.model_params["model_name"].lower() == "lstm": model = LSTMSpeakerEncoder( - c.model_params["input_dim"], - c.model_params["proj_dim"], - c.model_params["lstm_dim"], - c.model_params["num_lstm_layers"], + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], ) - elif c.model_params["model_name"].lower() == "resnet": + elif config.model_params["model_name"].lower() == "resnet": model = ResNetSpeakerEncoder( - input_dim=c.model_params["input_dim"], - proj_dim=c.model_params["proj_dim"], - log_input=c.model_params.get("log_input", False), - use_torch_spec=c.model_params.get("use_torch_spec", False), - audio_config=c.audio, + input_dim=config.model_params["input_dim"], + proj_dim=config.model_params["proj_dim"], + log_input=config.model_params.get("log_input", False), + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, ) return model