mirror of https://github.com/coqui-ai/TTS.git
37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
import importlib
|
|
import re
|
|
|
|
|
|
def to_camel(text):
    """Convert a snake_case name to CamelCase.

    The input is first passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased). Every underscore that is not at the
    very start of the string and is followed by an ASCII letter is then
    dropped and that letter upper-cased.
    """

    def _upper_letter(match):
        # Group 1 is the letter following the underscore.
        return match.group(1).upper()

    capitalized = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", _upper_letter, capitalized)
|
|
|
|
|
|
def setup_generator(c):
    """Build the TF vocoder generator selected by ``c.generator_model``.

    The module ``TTS.vocoder.tf.models.<generator_model, lower-cased>`` is
    imported and the class named ``to_camel(c.generator_model)`` is looked up
    in it, then instantiated with settings that depend on the model name.

    Args:
        c: config object exposing ``generator_model`` (str),
            ``audio["num_mels"]`` and ``generator_model_params`` with the
            keys ``"upsample_factors"`` and ``"num_res_blocks"``.

    Returns:
        The constructed generator model instance.

    Raises:
        NotImplementedError: if the model is ``melgan_fb_generator``, for
            which no construction settings are defined here.
        ValueError: if the model name is not one this function can set up.
    """
    print(" > Generator Model: {}".format(c.generator_model))
    module = importlib.import_module("TTS.vocoder.tf.models." + c.generator_model.lower())
    MyModel = getattr(module, to_camel(c.generator_model))

    # Look the name up by exact match. The previous ``name in "literal"``
    # checks were substring tests, so "melgan_generator" also matched
    # "multiband_melgan_generator" and the plain MelGAN model was silently
    # rebuilt with the multi-band settings.
    presets = {
        "melgan_generator": {"out_channels": 1, "base_channels": 512},
        "multiband_melgan_generator": {"out_channels": 4, "base_channels": 384},
    }
    preset = presets.get(c.generator_model)
    if preset is None:
        if c.generator_model == "melgan_fb_generator":
            # This branch used to fall through and fail with an
            # UnboundLocalError on ``return model``; fail explicitly instead.
            raise NotImplementedError(
                "Generator model 'melgan_fb_generator' is not supported by setup_generator"
            )
        raise ValueError(
            "Unknown generator model: {!r}".format(c.generator_model)
        )

    return MyModel(
        in_channels=c.audio["num_mels"],
        out_channels=preset["out_channels"],
        proj_kernel=7,
        base_channels=preset["base_channels"],
        upsample_factors=c.generator_model_params["upsample_factors"],
        res_kernel=3,
        num_res_blocks=c.generator_model_params["num_res_blocks"],
    )
|