This commit is contained in:
Eren Gölge 2021-12-13 16:29:19 +00:00
parent 6d7199d559
commit 6274d5e438
1 changed files with 12 additions and 12 deletions

View File

@ -170,21 +170,21 @@ def to_camel(text):
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_model(c):
if c.model_params["model_name"].lower() == "lstm":
def setup_speaker_encoder_model(config: "Coqpit"):
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
c.model_params["input_dim"],
c.model_params["proj_dim"],
c.model_params["lstm_dim"],
c.model_params["num_lstm_layers"],
config.model_params["input_dim"],
config.model_params["proj_dim"],
config.model_params["lstm_dim"],
config.model_params["num_lstm_layers"],
)
elif c.model_params["model_name"].lower() == "resnet":
elif config.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(
input_dim=c.model_params["input_dim"],
proj_dim=c.model_params["proj_dim"],
log_input=c.model_params.get("log_input", False),
use_torch_spec=c.model_params.get("use_torch_spec", False),
audio_config=c.audio,
input_dim=config.model_params["input_dim"],
proj_dim=config.model_params["proj_dim"],
log_input=config.model_params.get("log_input", False),
use_torch_spec=config.model_params.get("use_torch_spec", False),
audio_config=config.audio,
)
return model