add wavegrad to vocoder generators

This commit is contained in:
erogol 2020-10-16 16:34:07 +02:00
parent e723b99888
commit ac57eea928
1 changed files with 14 additions and 4 deletions

View File

@ -70,7 +70,7 @@ def setup_generator(c):
MyModel = importlib.import_module('TTS.vocoder.models.' +
c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model))
if c.generator_model in 'melgan_generator':
if c.generator_model.lower() in 'melgan_generator':
model = MyModel(
in_channels=c.audio['num_mels'],
out_channels=1,
@ -81,7 +81,7 @@ def setup_generator(c):
num_res_blocks=c.generator_model_params['num_res_blocks'])
if c.generator_model in 'melgan_fb_generator':
pass
if c.generator_model in 'multiband_melgan_generator':
if c.generator_model.lower() in 'multiband_melgan_generator':
model = MyModel(
in_channels=c.audio['num_mels'],
out_channels=4,
@ -90,7 +90,7 @@ def setup_generator(c):
upsample_factors=c.generator_model_params['upsample_factors'],
res_kernel=3,
num_res_blocks=c.generator_model_params['num_res_blocks'])
if c.generator_model in 'fullband_melgan_generator':
if c.generator_model.lower() in 'fullband_melgan_generator':
model = MyModel(
in_channels=c.audio['num_mels'],
out_channels=1,
@ -99,7 +99,7 @@ def setup_generator(c):
upsample_factors=c.generator_model_params['upsample_factors'],
res_kernel=3,
num_res_blocks=c.generator_model_params['num_res_blocks'])
if c.generator_model in 'parallel_wavegan_generator':
if c.generator_model.lower() in 'parallel_wavegan_generator':
model = MyModel(
in_channels=1,
out_channels=1,
@ -114,6 +114,16 @@ def setup_generator(c):
bias=True,
use_weight_norm=True,
upsample_factors=c.generator_model_params['upsample_factors'])
if c.generator_model.lower() in 'wavegrad':
model = MyModel(
in_channels=c['audio']['num_mels'],
out_channels=1,
x_conv_channels=c['model_params']['x_conv_channels'],
c_conv_channels=c['model_params']['c_conv_channels'],
dblock_out_channels=c['model_params']['dblock_out_channels'],
ublock_out_channels=c['model_params']['ublock_out_channels'],
upsample_factors=c['model_params']['upsample_factors'],
upsample_dilations=c['model_params']['upsample_dilations'])
return model