mirror of https://github.com/coqui-ai/TTS.git
Fixup
This commit is contained in:
parent
6d7199d559
commit
6274d5e438
|
@ -170,21 +170,21 @@ def to_camel(text):
|
|||
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||
|
||||
|
||||
def setup_model(c):
|
||||
if c.model_params["model_name"].lower() == "lstm":
|
||||
def setup_speaker_encoder_model(config: "Coqpit"):
|
||||
if config.model_params["model_name"].lower() == "lstm":
|
||||
model = LSTMSpeakerEncoder(
|
||||
c.model_params["input_dim"],
|
||||
c.model_params["proj_dim"],
|
||||
c.model_params["lstm_dim"],
|
||||
c.model_params["num_lstm_layers"],
|
||||
config.model_params["input_dim"],
|
||||
config.model_params["proj_dim"],
|
||||
config.model_params["lstm_dim"],
|
||||
config.model_params["num_lstm_layers"],
|
||||
)
|
||||
elif c.model_params["model_name"].lower() == "resnet":
|
||||
elif config.model_params["model_name"].lower() == "resnet":
|
||||
model = ResNetSpeakerEncoder(
|
||||
input_dim=c.model_params["input_dim"],
|
||||
proj_dim=c.model_params["proj_dim"],
|
||||
log_input=c.model_params.get("log_input", False),
|
||||
use_torch_spec=c.model_params.get("use_torch_spec", False),
|
||||
audio_config=c.audio,
|
||||
input_dim=config.model_params["input_dim"],
|
||||
proj_dim=config.model_params["proj_dim"],
|
||||
log_input=config.model_params.get("log_input", False),
|
||||
use_torch_spec=config.model_params.get("use_torch_spec", False),
|
||||
audio_config=config.audio,
|
||||
)
|
||||
return model
|
||||
|
||||
|
|
Loading…
Reference in New Issue