mirror of https://github.com/coqui-ai/TTS.git
Fixup
This commit is contained in:
parent
79de38ca76
commit
3818bd0c23
|
@ -170,21 +170,21 @@ def to_camel(text):
|
||||||
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||||
|
|
||||||
|
|
||||||
def setup_model(c):
|
def setup_speaker_encoder_model(config: "Coqpit"):
|
||||||
if c.model_params["model_name"].lower() == "lstm":
|
if config.model_params["model_name"].lower() == "lstm":
|
||||||
model = LSTMSpeakerEncoder(
|
model = LSTMSpeakerEncoder(
|
||||||
c.model_params["input_dim"],
|
config.model_params["input_dim"],
|
||||||
c.model_params["proj_dim"],
|
config.model_params["proj_dim"],
|
||||||
c.model_params["lstm_dim"],
|
config.model_params["lstm_dim"],
|
||||||
c.model_params["num_lstm_layers"],
|
config.model_params["num_lstm_layers"],
|
||||||
)
|
)
|
||||||
elif c.model_params["model_name"].lower() == "resnet":
|
elif config.model_params["model_name"].lower() == "resnet":
|
||||||
model = ResNetSpeakerEncoder(
|
model = ResNetSpeakerEncoder(
|
||||||
input_dim=c.model_params["input_dim"],
|
input_dim=config.model_params["input_dim"],
|
||||||
proj_dim=c.model_params["proj_dim"],
|
proj_dim=config.model_params["proj_dim"],
|
||||||
log_input=c.model_params.get("log_input", False),
|
log_input=config.model_params.get("log_input", False),
|
||||||
use_torch_spec=c.model_params.get("use_torch_spec", False),
|
use_torch_spec=config.model_params.get("use_torch_spec", False),
|
||||||
audio_config=c.audio,
|
audio_config=config.audio,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue