mirror of https://github.com/coqui-ai/TTS.git
load pwgan models
This commit is contained in:
parent
320bc29496
commit
69f525f17d
|
@ -67,14 +67,34 @@ def setup_generator(c):
|
||||||
upsample_factors=c.generator_model_params['upsample_factors'],
|
upsample_factors=c.generator_model_params['upsample_factors'],
|
||||||
res_kernel=3,
|
res_kernel=3,
|
||||||
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
||||||
|
if c.generator_model in 'parallel_wavegan_generator':
|
||||||
|
model = MyModel(
|
||||||
|
in_channels=1,
|
||||||
|
out_channels=1,
|
||||||
|
kernel_size=3,
|
||||||
|
num_res_blocks=c.generator_model_params['num_res_blocks'],
|
||||||
|
stacks=c.generator_model_params['stacks'],
|
||||||
|
res_channels=64,
|
||||||
|
gate_channels=128,
|
||||||
|
skip_channels=64,
|
||||||
|
aux_channels=c.audio['num_mels'],
|
||||||
|
aux_context_window=c['conv_pad'],
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
use_weight_norm=True,
|
||||||
|
upsample_conditional_features=True,
|
||||||
|
upsample_factors=c.generator_model_params['upsample_factors'])
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def setup_discriminator(c):
|
def setup_discriminator(c):
|
||||||
print(" > Discriminator Model: {}".format(c.discriminator_model))
|
print(" > Discriminator Model: {}".format(c.discriminator_model))
|
||||||
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
if 'parallel_wavegan' in c.discriminator_model:
|
||||||
c.discriminator_model.lower())
|
MyModel = importlib.import_module('TTS.vocoder.models.parallel_wavegan_discriminator')
|
||||||
MyModel = getattr(MyModel, to_camel(c.discriminator_model))
|
else:
|
||||||
|
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
||||||
|
c.discriminator_model.lower())
|
||||||
|
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
|
||||||
if c.discriminator_model in 'random_window_discriminator':
|
if c.discriminator_model in 'random_window_discriminator':
|
||||||
model = MyModel(
|
model = MyModel(
|
||||||
cond_channels=c.audio['num_mels'],
|
cond_channels=c.audio['num_mels'],
|
||||||
|
@ -95,6 +115,33 @@ def setup_discriminator(c):
|
||||||
max_channels=c.discriminator_model_params['max_channels'],
|
max_channels=c.discriminator_model_params['max_channels'],
|
||||||
downsample_factors=c.
|
downsample_factors=c.
|
||||||
discriminator_model_params['downsample_factors'])
|
discriminator_model_params['downsample_factors'])
|
||||||
|
if c.discriminator_model == 'residual_parallel_wavegan_discriminator':
|
||||||
|
model = MyModel(
|
||||||
|
in_channels=1,
|
||||||
|
out_channels=1,
|
||||||
|
kernel_size=3,
|
||||||
|
num_layers=c.discriminator_model_params['num_layers'],
|
||||||
|
stacks=c.discriminator_model_params['stacks'],
|
||||||
|
res_channels=64,
|
||||||
|
gate_channels=128,
|
||||||
|
skip_channels=64,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
nonlinear_activation="LeakyReLU",
|
||||||
|
nonlinear_activation_params={"negative_slope": 0.2},
|
||||||
|
)
|
||||||
|
if c.discriminator_model == 'parallel_wavegan_discriminator':
|
||||||
|
model = MyModel(
|
||||||
|
in_channels=1,
|
||||||
|
out_channels=1,
|
||||||
|
kernel_size=3,
|
||||||
|
num_layers=c.discriminator_model_params['num_layers'],
|
||||||
|
conv_channels=64,
|
||||||
|
dilation_factor=1,
|
||||||
|
nonlinear_activation="LeakyReLU",
|
||||||
|
nonlinear_activation_params={"negative_slope": 0.2},
|
||||||
|
bias=True
|
||||||
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue