mirror of https://github.com/coqui-ai/TTS.git
add wavegrad to vocoder generators
This commit is contained in:
parent
e723b99888
commit
ac57eea928
|
@ -70,7 +70,7 @@ def setup_generator(c):
|
||||||
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
||||||
c.generator_model.lower())
|
c.generator_model.lower())
|
||||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||||
if c.generator_model in 'melgan_generator':
|
if c.generator_model.lower() in 'melgan_generator':
|
||||||
model = MyModel(
|
model = MyModel(
|
||||||
in_channels=c.audio['num_mels'],
|
in_channels=c.audio['num_mels'],
|
||||||
out_channels=1,
|
out_channels=1,
|
||||||
|
@ -81,7 +81,7 @@ def setup_generator(c):
|
||||||
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
||||||
if c.generator_model in 'melgan_fb_generator':
|
if c.generator_model in 'melgan_fb_generator':
|
||||||
pass
|
pass
|
||||||
if c.generator_model in 'multiband_melgan_generator':
|
if c.generator_model.lower() in 'multiband_melgan_generator':
|
||||||
model = MyModel(
|
model = MyModel(
|
||||||
in_channels=c.audio['num_mels'],
|
in_channels=c.audio['num_mels'],
|
||||||
out_channels=4,
|
out_channels=4,
|
||||||
|
@ -90,7 +90,7 @@ def setup_generator(c):
|
||||||
upsample_factors=c.generator_model_params['upsample_factors'],
|
upsample_factors=c.generator_model_params['upsample_factors'],
|
||||||
res_kernel=3,
|
res_kernel=3,
|
||||||
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
||||||
if c.generator_model in 'fullband_melgan_generator':
|
if c.generator_model.lower() in 'fullband_melgan_generator':
|
||||||
model = MyModel(
|
model = MyModel(
|
||||||
in_channels=c.audio['num_mels'],
|
in_channels=c.audio['num_mels'],
|
||||||
out_channels=1,
|
out_channels=1,
|
||||||
|
@ -99,7 +99,7 @@ def setup_generator(c):
|
||||||
upsample_factors=c.generator_model_params['upsample_factors'],
|
upsample_factors=c.generator_model_params['upsample_factors'],
|
||||||
res_kernel=3,
|
res_kernel=3,
|
||||||
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
||||||
if c.generator_model in 'parallel_wavegan_generator':
|
if c.generator_model.lower() in 'parallel_wavegan_generator':
|
||||||
model = MyModel(
|
model = MyModel(
|
||||||
in_channels=1,
|
in_channels=1,
|
||||||
out_channels=1,
|
out_channels=1,
|
||||||
|
@ -114,6 +114,16 @@ def setup_generator(c):
|
||||||
bias=True,
|
bias=True,
|
||||||
use_weight_norm=True,
|
use_weight_norm=True,
|
||||||
upsample_factors=c.generator_model_params['upsample_factors'])
|
upsample_factors=c.generator_model_params['upsample_factors'])
|
||||||
|
if c.generator_model.lower() in 'wavegrad':
|
||||||
|
model = MyModel(
|
||||||
|
in_channels=c['audio']['num_mels'],
|
||||||
|
out_channels=1,
|
||||||
|
x_conv_channels=c['model_params']['x_conv_channels'],
|
||||||
|
c_conv_channels=c['model_params']['c_conv_channels'],
|
||||||
|
dblock_out_channels=c['model_params']['dblock_out_channels'],
|
||||||
|
ublock_out_channels=c['model_params']['ublock_out_channels'],
|
||||||
|
upsample_factors=c['model_params']['upsample_factors'],
|
||||||
|
upsample_dilations=c['model_params']['upsample_dilations'])
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue