mirror of https://github.com/coqui-ai/TTS.git
Implement `setup_model` for vocoder models
This commit is contained in:
parent
e949e7ad58
commit
d18198dff8
|
@ -0,0 +1,147 @@
|
|||
import importlib
|
||||
import re
|
||||
|
||||
from coqpit import Coqpit
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||
|
||||
|
||||
def setup_model(config: Coqpit):
|
||||
"""Load models directly from configuration."""
|
||||
print(" > Vocoder Model: {}".format(config.model))
|
||||
if "discriminator_model" in config and "generator_model" in config:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models.gan")
|
||||
MyModel = getattr(MyModel, "GAN")
|
||||
else:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower())
|
||||
if config.model.lower() == "wavernn":
|
||||
MyModel = getattr(MyModel, "Wavernn")
|
||||
elif config.model.lower() == "gan":
|
||||
MyModel = getattr(MyModel, "GAN")
|
||||
elif config.model.lower() == "wavegrad":
|
||||
MyModel = getattr(MyModel, "Wavegrad")
|
||||
else:
|
||||
MyModel = getattr(MyModel, to_camel(config.model))
|
||||
raise ValueError(f"Model {config.model} not exist!")
|
||||
model = MyModel(config)
|
||||
return model
|
||||
|
||||
|
||||
def setup_generator(c):
|
||||
print(" > Generator Model: {}".format(c.generator_model))
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||
# this is to preserve the Wavernn class name (instead of Wavernn)
|
||||
if c.generator_model.lower() in "hifigan_generator":
|
||||
model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
|
||||
elif c.generator_model.lower() in "melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model in "melgan_fb_generator":
|
||||
raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
|
||||
elif c.generator_model.lower() in "multiband_melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=4,
|
||||
proj_kernel=7,
|
||||
base_channels=384,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model.lower() in "fullband_melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model.lower() in "parallel_wavegan_generator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
stacks=c.generator_model_params["stacks"],
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
aux_channels=c.audio["num_mels"],
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
use_weight_norm=True,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Model {c.generator_model} not implemented!")
|
||||
return model
|
||||
|
||||
|
||||
def setup_discriminator(c):
|
||||
print(" > Discriminator Model: {}".format(c.discriminator_model))
|
||||
if "parallel_wavegan" in c.discriminator_model:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
|
||||
else:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
|
||||
if c.discriminator_model in "hifigan_discriminator":
|
||||
model = MyModel()
|
||||
if c.discriminator_model in "random_window_discriminator":
|
||||
model = MyModel(
|
||||
cond_channels=c.audio["num_mels"],
|
||||
hop_length=c.audio["hop_length"],
|
||||
uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
|
||||
cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
|
||||
cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
|
||||
window_sizes=c.discriminator_model_params["window_sizes"],
|
||||
)
|
||||
if c.discriminator_model in "melgan_multiscale_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=c.discriminator_model_params["base_channels"],
|
||||
max_channels=c.discriminator_model_params["max_channels"],
|
||||
downsample_factors=c.discriminator_model_params["downsample_factors"],
|
||||
)
|
||||
if c.discriminator_model == "residual_parallel_wavegan_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=c.discriminator_model_params["num_layers"],
|
||||
stacks=c.discriminator_model_params["stacks"],
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
)
|
||||
if c.discriminator_model == "parallel_wavegan_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=c.discriminator_model_params["num_layers"],
|
||||
conv_channels=64,
|
||||
dilation_factor=1,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
bias=True,
|
||||
)
|
||||
return model
|
Loading…
Reference in New Issue